Wednesday, March 17, 2010

Selecting k smallest or largest elements

There are cases when you need to select a number of the best (according to some definition) elements out of a finite sequence (list). For example, select the 10 most popular baby names in a particular year or the 10 biggest files on your hard drive.

While selecting a single minimum or maximum element can easily be done iteratively in O(n), selecting the k smallest or largest elements (k smallest for short) is not that simple.
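For reference, the single-element case needs only one running candidate. The sketch below (the class and method names are illustrative, not from any library) finds the minimum in one pass; a naive extension that keeps k candidates would compare each element against up to k kept ones, giving O(nk).

```csharp
using System;
using System.Collections.Generic;

public static class SinglePassMin
{
    // Returns the smallest element of source in a single O(n) pass.
    // "MinOf" is a hypothetical helper name used only for illustration.
    public static T MinOf<T>(IEnumerable<T> source, IComparer<T> comparer)
    {
        using (var e = source.GetEnumerator())
        {
            if (!e.MoveNext())
            {
                throw new InvalidOperationException("Sequence contains no elements");
            }
            var min = e.Current;
            while (e.MoveNext())
            {
                // Replace the candidate only when a strictly smaller element appears
                if (comparer.Compare(e.Current, min) < 0)
                {
                    min = e.Current;
                }
            }
            return min;
        }
    }
}
```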

It makes sense to take advantage of the composability of sequence APIs. We’ll design an extension method with the signature below:

public static IEnumerable<TSource> TakeSmallest<TSource>(
 this IEnumerable<TSource> source, int count, IComparer<TSource> comparer)

The name originates from the fact that selecting the k smallest elements can logically be expressed in terms of Enumerable.TakeWhile, supplying a predicate that returns true if an element is one of the k smallest. As this logical predicate never changes (only the count does), it is burned into the method’s name: instead of “While”, which stands for an arbitrary caller-supplied predicate, we have “Smallest”.

Now let’s find the solution.

If the whole list is sorted, the first k elements are what we are looking for.

public static IEnumerable<TSource> TakeSmallest<TSource>(
 this IEnumerable<TSource> source, int count, IComparer<TSource> comparer)
{
 return source.OrderBy(x => x, comparer).Take(count);
}

This is an O(n log n) solution, where n is the number of elements in the source sequence. We can do better.
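The title also promises the k largest elements. A minimal sketch, assuming .NET Framework 4.5 or later for Comparer<T>.Create: the k largest elements are the k smallest ones under the reversed order, so a hypothetical TakeLargest helper only has to flip the comparer. The block restates the sort-based TakeSmallest so that it compiles on its own; any of the implementations in this post could be used instead.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LargestViaSmallest
{
    // Sort-based baseline: O(n log n)
    public static IEnumerable<T> TakeSmallest<T>(
        this IEnumerable<T> source, int count, IComparer<T> comparer)
    {
        return source.OrderBy(x => x, comparer).Take(count);
    }

    // Hypothetical helper: the k largest elements are the k smallest
    // under the reversed order, so only the comparer changes.
    public static IEnumerable<T> TakeLargest<T>(
        this IEnumerable<T> source, int count, IComparer<T> comparer)
    {
        var reversed = Comparer<T>.Create((x, y) => comparer.Compare(y, x));
        return source.TakeSmallest(count, reversed);
    }
}
```

For the "10 biggest files" example, the comparer would compare file lengths and the call would be TakeLargest(10, ...).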

A priority queue yields better performance when only a subset of the sorted sequence is required.

public static IEnumerable<TSource> TakeSmallest<TSource>(
 this IEnumerable<TSource> source, int count, IComparer<TSource> comparer)
{
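 // PriorityQueue<TSource> is a custom binary min heap built from source;
 // it is not the PriorityQueue<TElement, TPriority> type added in .NET 6
 var queue = new PriorityQueue<TSource>(source, comparer);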
 while (count > 0 && queue.Count > 0)
 {
  yield return queue.Dequeue();
  count--;
 }
}

It takes O(n) to build a priority queue based on a binary min heap and O(k log n) to retrieve the first k elements, O(n + k log n) overall. Better, but we can improve further.
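The post does not show its PriorityQueue<T>. As a stand-in under that assumption, here is a minimal array-backed binary min heap: bottom-up construction sifts down every internal node (O(n) in total), and each of the k extractions costs O(log n). The names are hypothetical, and unlike the iterator above this sketch collects its results eagerly.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class HeapSelect
{
    // Hypothetical stand-in for the post's PriorityQueue<T>: heapify in O(n),
    // then extract the minimum up to count times, O(log n) each.
    public static IEnumerable<T> TakeSmallestViaHeap<T>(
        IEnumerable<T> source, int count, IComparer<T> comparer)
    {
        var heap = source.ToArray();
        var n = heap.Length;
        // Bottom-up construction: sift down each internal node, last one first
        for (var i = n / 2 - 1; i >= 0; i--)
        {
            SiftDown(heap, i, n, comparer);
        }
        var result = new List<T>();
        while (count > 0 && n > 0)
        {
            // The root is the minimum; move the last leaf to the root and repair
            result.Add(heap[0]);
            heap[0] = heap[--n];
            SiftDown(heap, 0, n, comparer);
            count--;
        }
        return result;
    }

    private static void SiftDown<T>(T[] heap, int i, int n, IComparer<T> comparer)
    {
        while (true)
        {
            var left = 2 * i + 1;
            if (left >= n) return;
            // Pick the smaller of the two children
            var child = left + 1 < n && comparer.Compare(heap[left + 1], heap[left]) < 0
                ? left + 1
                : left;
            if (comparer.Compare(heap[child], heap[i]) >= 0) return;
            var tmp = heap[i];
            heap[i] = heap[child];
            heap[child] = tmp;
            i = child;
        }
    }
}
```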

The quicksort algorithm picks a pivot element and reorders the elements so that the ones less than the pivot go before it, while greater elements go after it (equal ones can go either way). After that the pivot is in its final position. Then both partitions are sorted recursively, making the whole sequence sorted. To make the worst case unlikely, the pivot can be selected at random.
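A minimal randomized quicksort, for reference. The class name is illustrative, and the Lomuto partition scheme is a choice made for this sketch; it matches the partitioning that RandomizedPartition uses further down.

```csharp
using System;
using System.Collections.Generic;

public static class RandomizedQuicksort
{
    private static readonly Random Rng = new Random();

    public static void Sort<T>(T[] a, IComparer<T> comparer)
    {
        Sort(a, 0, a.Length - 1, comparer);
    }

    private static void Sort<T>(T[] a, int l, int u, IComparer<T> comparer)
    {
        if (l >= u) return;
        var q = Partition(a, l, u, comparer);
        // The pivot at q is in its final position; sort both sides
        Sort(a, l, q - 1, comparer);
        Sort(a, q + 1, u, comparer);
    }

    // Lomuto partition around a randomly chosen pivot moved to a[u]
    private static int Partition<T>(T[] a, int l, int u, IComparer<T> comparer)
    {
        Swap(a, Rng.Next(l, u + 1), u);
        var k = l;
        for (var i = l; i < u; i++)
        {
            if (comparer.Compare(a[i], a[u]) < 0)
            {
                Swap(a, k++, i);
            }
        }
        Swap(a, k, u);
        return k;
    }

    private static void Swap<T>(T[] a, int i, int j)
    {
        var tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }
}
```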

We are interested only in the k smallest elements themselves, not in the ordering between them, so after each partitioning step only one partition has to be processed further (this idea is known as quickselect, or Hoare’s selection algorithm). Assuming partitioning has just completed, let’s denote the set of elements before the pivot (including the pivot itself) by L and the set of elements after the pivot by H. By the definition of partitioning, L contains the |L| smallest elements (where |X| denotes the number of elements in a set X). If |L| is equal to k, we are done. If k is less than |L|, look for the k smallest elements in L; the pivot is not smaller than any other element of L, so it can be left out of that search. Otherwise, as we already have the |L| smallest elements, look for the k - |L| smallest elements in H.

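using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;

// Extension methods must live in a static class; the name is illustrative
public static class EnumerableExtensions
{
  public static IEnumerable<TSource> TakeSmallest<TSource>(
    this IEnumerable<TSource> source, int count)
  {
    return TakeSmallest(source, count, Comparer<TSource>.Default);
  }

  public static IEnumerable<TSource> TakeSmallest<TSource>(
    this IEnumerable<TSource> source, int count, IComparer<TSource> comparer)
  {
    // Contract.Requires<TException> needs the Code Contracts binary rewriter
    Contract.Requires<ArgumentNullException>(source != null);
    // Sieve handles the situation when count >= source.Count()
    Contract.Requires<ArgumentOutOfRangeException>(count > 0);
    Contract.Requires<ArgumentNullException>(comparer != null);

    return new Sieve<TSource>(source, count, comparer);
  }
}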

class Sieve<T> : IEnumerable<T>
{
  private readonly IEnumerable<T> m_source;
  private readonly IComparer<T> m_comparer;
  private readonly int m_count;

  private readonly Random m_random;

  public Sieve(IEnumerable<T> source, int count, IComparer<T> comparer)
  {
    m_source = source;
    m_count = count;
    m_comparer = comparer;
    m_random = new Random();
  }

  public IEnumerator<T> GetEnumerator()
  {
    var col = m_source as ICollection<T>;
    if (col != null && m_count >= col.Count)
    {
      // There is no point in copying data: every element qualifies
      return m_source.GetEnumerator();
    }
    var buf = m_source.ToArray();
    if (m_count >= buf.Length)
    {
      // The buffer holds at most count elements, so all of them qualify
      return buf.AsEnumerable().GetEnumerator();
    }
    // Find the solution
    return GetEnumerator(buf);
  }

  IEnumerator IEnumerable.GetEnumerator()
  {
    return GetEnumerator();
  }

  private IEnumerator<T> GetEnumerator(T[] buf)
  {
    var n = buf.Length;
    var k = m_count;
    // After Rearrange completes, the first k
    // items are the k smallest elements
    Rearrange(buf, 0, n - 1, k);
    for (int i = 0; i < k; i++)
    {
      yield return buf[i];
    }
  }

  private void Rearrange(T[] buf, int l, int u, int k)
  {
    if (l == u)
    {
      return;
    }
    // Partition elements around randomly selected pivot
    var q = RandomizedPartition(buf, l, u);
    // Compute size of low partition (includes pivot)
    var s = q - l + 1;
    // We are done: the low partition is exactly the k smallest elements
    if (k == s)
    {
      return;
    }

    if (k < s)
    {
      // The k smallest elements all lie in the low partition
      // (the pivot, its largest element, can be excluded); look there
      Rearrange(buf, l, q - 1, k);
    }
    else
    {
      // The whole low partition belongs to the k smallest elements;
      // find the remaining k - s in the high partition
      Rearrange(buf, q + 1, u, k - s);
    }
  }

  private int RandomizedPartition(T[] buf, int l, int u)
  {
    // Select the pivot at random and swap it with the last element;
    // random selection makes the worst case (always picking the
    // smallest or largest remaining element) unlikely
    Swap(buf, m_random.Next(l, u + 1), u);
    // Divide the elements into two partitions:
    // - low partition: elements less than the pivot, plus the pivot itself
    // - high partition: the rest
    var k = l;
    for (var i = l; i < u; i++)
    {
      if (m_comparer.Compare(buf[i], buf[u]) < 0)
      {
        Swap(buf, k++, i);
      }
    }
    // Put pivot into its final location
    Swap(buf, k, u);
    return k;
  }

  private static void Swap(T[] a, int i, int j)
  {
    var tmp = a[i];
    a[i] = a[j];
    a[j] = tmp;
  }
}
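Since Rearrange recurses into only one partition and the recursive call is its last action, the recursion can be replaced by a loop. A sketch under that observation (names hypothetical; same randomized Lomuto partition) that keeps stack usage constant even when many lopsided partitions in a row would otherwise make the recursion deep:

```csharp
using System;
using System.Collections.Generic;

public static class IterativeSelect
{
    private static readonly Random Rng = new Random();

    // Rearranges buf so that its first k items are the k smallest ones.
    // Assumes 1 <= k <= buf.Length.
    public static void Rearrange<T>(T[] buf, int k, IComparer<T> comparer)
    {
        int l = 0, u = buf.Length - 1;
        while (l < u)
        {
            var q = Partition(buf, l, u, comparer);
            var s = q - l + 1;   // size of the low partition, pivot included
            if (k == s)
            {
                return;          // the low partition is exactly the answer
            }
            if (k < s)
            {
                u = q - 1;       // the answer lies inside the low partition
            }
            else
            {
                k -= s;          // the whole low partition is part of the answer
                l = q + 1;       // find the rest in the high partition
            }
        }
    }

    // Randomized Lomuto partition, as in Sieve<T>.RandomizedPartition
    private static int Partition<T>(T[] a, int l, int u, IComparer<T> comparer)
    {
        Swap(a, Rng.Next(l, u + 1), u);
        var k = l;
        for (var i = l; i < u; i++)
        {
            if (comparer.Compare(a[i], a[u]) < 0)
            {
                Swap(a, k++, i);
            }
        }
        Swap(a, k, u);
        return k;
    }

    private static void Swap<T>(T[] a, int i, int j)
    {
        var tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }
}
```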

The solution runs in expected O(n) time, which means quite good performance in practice. Let’s test it against the sort-based version.
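One caveat on the expected O(n) bound: it assumes the elements are mostly distinct. RandomizedPartition sends elements equal to the pivot into the high partition, so on an array of identical values every step removes only the pivot and selecting k elements costs O(nk), which is quadratic when k is proportional to n. A three-way (Dutch national flag) partition, a technique swapped in here and not used in the post, groups all copies of the pivot value together so they can be accepted or skipped at once. A sketch, with hypothetical names:

```csharp
using System;
using System.Collections.Generic;

public static class ThreeWaySelect
{
    private static readonly Random Rng = new Random();

    // Rearranges buf so that its first k items are the k smallest ones.
    // Assumes 1 <= k <= buf.Length.
    public static void Rearrange<T>(T[] buf, int k, IComparer<T> comparer)
    {
        int l = 0, u = buf.Length - 1;
        while (l < u)
        {
            int lt, gt;
            Partition3(buf, l, u, comparer, out lt, out gt);
            var less = lt - l;            // elements strictly smaller than the pivot
            var lessOrEqual = gt - l + 1; // the same plus every copy of the pivot value
            if (k < less)
            {
                u = lt - 1;               // answer lies among the smaller elements
            }
            else if (k <= lessOrEqual)
            {
                return;                   // smaller elements plus k - less pivot copies
            }
            else
            {
                k -= lessOrEqual;         // take all of them, search the larger ones
                l = gt + 1;
            }
        }
    }

    // Dijkstra's three-way partition of a[l..u] around a random pivot value:
    // a[l..lt-1] < pivot, a[lt..gt] == pivot, a[gt+1..u] > pivot
    private static void Partition3<T>(T[] a, int l, int u, IComparer<T> comparer,
        out int lt, out int gt)
    {
        var pivot = a[Rng.Next(l, u + 1)];
        lt = l;
        gt = u;
        var i = l;
        while (i <= gt)
        {
            var c = comparer.Compare(a[i], pivot);
            if (c < 0) Swap(a, lt++, i++);
            else if (c > 0) Swap(a, i, gt--);
            else i++;
        }
    }

    private static void Swap<T>(T[] a, int i, int j)
    {
        var tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }
}
```

With all elements equal, the first partition puts every element into the middle group and the loop returns immediately.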

const int count = 100;
const int max = 100;
var rnd = new Random();
var seq = Enumerable.Range(0, count).Select(_ => rnd.Next(max)).ToArray();
Func<int, int> i = x => x;

for(var k = 1; k < count / 2; k++)
{
  var a = seq.TakeSmallest(k).OrderBy(i);
  var b = seq.OrderBy(i).Take(k);

  Debug.Assert(a.SequenceEqual(b));
}

Enjoy!
