guava中topK算法实现解析

对于topK问题,我们首先想到的是通过最小(大)堆实现,jdk也提供了相关的实现PriorityQueue,可以很方便地实现,建堆时间复杂度O(n),实现原理不赘述,同时guava也提供了相关的工具,在Ordering中提供了leastOf和对应的greatestof方法,获取集合中的最小或最大的k个元素,实现原理和优先队列不一样。可以看到guava的设计者更注重工程实践的需要。

public <E extends T> List<E> leastOf(Iterator<E> elements, int k) {
   checkNotNull(elements);
   checkNonnegative(k, "k");

   if (k == 0 || !elements.hasNext()) {
     return ImmutableList.of();
   } else if (k >= Integer.MAX_VALUE / 2) {
     // k is really large; just do a straightforward sorted-copy-and-sublist
     ArrayList<E> list = Lists.newArrayList(elements);
     Collections.sort(list, this);
     if (list.size() > k) {
       list.subList(k, list.size()).clear();
     }
     list.trimToSize();
     return Collections.unmodifiableList(list);
   }
   // and then ......
 }

可以看到,有一些防御性的判断,如果k >= Integer.MAX_VALUE / 2,直接排序然后取子序列,因为下面的算法需要分配2*k的数组,如果k太大,还不如直接排序。

    int bufferCap = k * 2;
    @SuppressWarnings("unchecked") // we'll only put E's in
    E[] buffer = (E[]) new Object[bufferCap];
    E threshold = elements.next();
    buffer[0] = threshold;
    int bufferSize = 1;
    // threshold is the kth smallest element seen so far.  Once bufferSize >= k,
    // anything larger than threshold can be ignored immediately.

    while (bufferSize < k && elements.hasNext()) {
      E e = elements.next();
      buffer[bufferSize++] = e;
      threshold = max(threshold, e);
    }

然后分配2*k长度的数组,先填充k个元素,threshold记录了当前最小的k个元素的上限。

      E e = elements.next();
      if (compare(e, threshold) >= 0) {
        continue;
      }

      buffer[bufferSize++] = e;

k+1个元素开始,根据threshold的定义可以知道,如果大于threshold,肯定不是topK的,直接抛弃。

if (bufferSize == bufferCap) {
        // We apply the quickselect algorithm to partition about the median,
        // and then ignore the last k elements.
        int left = 0;
        int right = bufferCap - 1;

        int minThresholdPosition = 0;
        // The leftmost position at which the greatest of the k lower elements
        // -- the new value of threshold -- might be found.

        while (left < right) {
          int pivotIndex = (left + right + 1) >>> 1;
          int pivotNewIndex = partition(buffer, left, right, pivotIndex);
          if (pivotNewIndex > k) {
            right = pivotNewIndex - 1;
          } else if (pivotNewIndex < k) {
            left = Math.max(pivotNewIndex, left + 1);
            minThresholdPosition = pivotNewIndex;
          } else {
            break;
          }
        }
        bufferSize = k;

        threshold = buffer[minThresholdPosition];
        for (int i = minThresholdPosition + 1; i < bufferSize; i++) {
          threshold = max(threshold, buffer[i]);
        }
      }

当填满2*k个元素后,就通过快速选择算法找到中位数k的值,那么左边就是当前元素序列的最小topK,右边舍弃,循环遍历余下的序列,最终就会得到原始输入的所有元素的最小的k个元素。
guava提供的算法,实现了时间复杂度O(n)和空间复杂度O(k),不需要把所有的元素都载入数组,相对于PriorityQueue,实现过程复杂一些,但使用更小的空间,这对于工程实践是很重要的。

如果觉得我的文章对您有用,请在支付宝公益平台找个项目捐点钱。 @sxzhou Mar 11, 2018

奉献爱心