from __future__ import print_function

import numpy as np

# Tim Peters' select by median-of-median-of-k
# https://mail.python.org/pipermail/python-list/2002-May/132170.html
#
# with modifications by Steven D'Aprano.


def select_by_sorting(a, rank):
    """Return rank'th item of list a by sorting.

    ``rank`` is a zero-based index into the sorted order and must satisfy
    ``0 <= rank < len(a)``.  May modify a (it is sorted in place).
    """
    assert 0 <= rank < len(a)
    a.sort()
    return a[rank]


def selectk(k, a, rank):
    """Return rank'th item of list a using median-of-median-of-k.

    The list is cut into consecutive groups of ``k`` items (a trailing
    partial group is ignored), the high median of every group is taken,
    and the median of those medians -- found by a recursive call -- is
    used as the pivot of a two-pointer partition.  The search then
    recurses into the side that contains ``rank``.

    Lists of at most ``k`` items are answered by sorting.  The same
    fallback is used whenever ``k < 3``, because the pivot scheme cannot
    guarantee progress there: with ``k == 1`` the list of medians is as
    long as ``a`` itself, and with ``k == 2`` the pivot can be the unique
    maximum sitting in the last position (e.g. ``[0, 1, 2, 3]``), which
    leaves the partition with nothing on the right-hand side.

    ``rank`` is zero-based and must satisfy ``0 <= rank < len(a)``.

    May modify a in place.
    """
    n = len(a)
    assert 0 <= rank < n
    if n <= k or k < 3:
        return sorted(a)[rank]
    # Find the high median-of-median-of-k's.
    mid = k // 2
    medians = [sorted(a[i:i+k])[mid] for i in range(0, n-k+1, k)]
    median = selectk(k, medians, len(medians)//2)
    # Partition around the median.
    # a[:i] <= median <= a[j+1:]
    i, j = 0, n-1
    while i <= j:
        while a[i] < median:
            i += 1
        while a[j] > median:
            j -= 1
        if i <= j:
            a[i], a[j] = a[j], a[i]
            i += 1
            j -= 1
    # Slicing copies, so the recursive calls work on fresh lists.
    if rank < i:
        return selectk(k, a[:i], rank)
    return selectk(k, a[i:], rank-i)


# Thomas Dybdahl Ahle's select and select2.
# http://pastebin.com/30x0j39a


def select(xs, k):
    """ Returns sorted(xs)[k] using expected linear time

    Rearranges ``xs`` in place.  The pivot ``x`` is isolated by
    partitioning on the half-open value range ``[x, x+1)``, so the
    "middle segment equals the pivot" step assumes integer items.
    """
    # Inv: xs[:a] <= res < xs[b:]
    a, b = 0, len(xs)
    while a + 1 != b:
        x = choose_pivot(xs, a, b)
        # Partition xs such that xs[:i] < x, xs[i:j] = x, xs[j:] > x
        i, j = partition3(xs, a, b, x, x+1)
        if k < i:
            b = i
        # If i <= res < j, res must be equal to the pivot
        elif k < j:
            a, b = i, i+1
        else:
            a = j
    return xs[a]


def choose_pivot(xs, a, b):
    """ Chooses a single pivot

    The pivot is the median of the first, middle and last items of
    ``xs[a:b]``.
    """
    x = median_of_three(xs[a], xs[(a+b)//2], xs[b-1])
    return x


def median_of_three(x, y, z):
    """ Median of three for stability

    Computed arithmetically (sum minus smallest minus largest), so the
    arguments must support ``+`` and ``-``.
    """
    return x + y + z - (min(x, y, z) + max(x, y, z))


def select2(xs, a, b, k):
    """ Dual pivot, ala Java 7 quick sort, version of select

    Returns the item that would sit at index ``k`` if ``xs`` were sorted,
    searching only ``xs[a:b]`` (``k`` must lie in that range).  Rearranges
    ``xs`` in place and recurses on the segment containing ``k``.
    """
    if a + 1 == b:
        return xs[a]
    x, y = choose_pivots(xs, a, b)
    i, j = partition3(xs, a, b, x, y)
    if k < i:
        return select2(xs, a, i, k)
    if k >= j:
        return select2(xs, j, b, k)
    # If all the values in the middle segment are equal,
    # we have to return early to ensure termination.
    if x + 1 == y:
        return x
    return select2(xs, i, j, k)


def select2b(xs, k, lim=100):
    """ Dual pivot, ala Java 7 quick sort, version of select

    Iterative variant of ``select2`` over the whole list.  After ``lim``
    partitioning rounds without narrowing the search to a single item,
    the remaining segment ``xs[a:b]`` is sorted and the answer is read
    from it.  Rearranges ``xs`` in place.
    """
    depth = 0
    a, b = 0, len(xs)
    while a + 1 != b and depth != lim:
        x, y = choose_pivots(xs, a, b)
        i, j = partition3(xs, a, b, x, y)
        if k < i:
            b = i
        elif k >= j:
            a = j
        else:
            # If all the values in the middle segment are equal,
            # we have to return early to ensure termination.
            if x + 1 == y:
                return x
            a, b = i, j
        depth += 1
    if depth == lim:
        # Select by sorting the remaining segment.  Everything left of
        # ``a`` ranks below it, so the wanted item has rank ``k - a``
        # within the slice.
        return sorted(xs[a:b])[k - a]
    return xs[a]


def choose_pivots(xs, a, b):
    """ Chooses two pivots, similarly to how Java 7 does it

    Takes the items one third and two thirds of the way through
    ``xs[a:b]`` and returns them as ``(low, high)`` with ``low < high``.
    When both are equal, ``(x, x+1)`` is returned so the pair is still
    strictly increasing.
    """
    third = (b-a)//3
    x, y = xs[a+third], xs[b-1-third]
    if x < y:
        return x, y
    if x == y:
        return x, x+1
    return y, x


def partition3(xs, a, b, x, y):
    """ Three-way in-place partition of xs[a:b] around x < y; returns (i, j)

    Post cond: xs[a:i] < x <= xs[i:j] < y <= xs[j:b]
    Inv: xs[a:i] < x <= xs[i:j] < y <= xs[k:b]
    """
    assert 0 <= a < b <= len(xs) and x < y
    i, j, k = a, a, b
    while j != k:
        if xs[j] < x:
            xs[i], xs[j] = xs[j], xs[i]
            i, j = i+1, j+1
        elif xs[j] < y:
            j = j+1
        else:
            xs[j], xs[k-1] = xs[k-1], xs[j]
            k = k-1
    return i, j


# === Test code ===

import random


def validate():
    """Confirm that all select functions give the correct result."""
    ranks = list(range(500))
    data = [1000 + r for r in ranks]
    random.shuffle(ranks)  # Test in random order.
    for rank in ranks:
        random.shuffle(data)
        assert select_by_sorting(data, rank) == 1000+rank
        for k in (7, 23, 47, 97):
            random.shuffle(data)
            assert selectk(k, data, rank) == 1000+rank
        random.shuffle(data)
        assert select(data, rank) == 1000+rank
        random.shuffle(data)
        assert select2(data, 0, 500, rank) == 1000+rank
        random.shuffle(data)
        assert select2b(data, rank) == 1000+rank


# Helper class for timings. Since we mostly care about large data sets,
# the times will be moderately large (e.g. in excess of a second). In
# that case, we don't bother with timeit, which is designed for timing
# small snippets with millisecond timings.
class Stopwatch:
    """Time hefty or long-running block of code using a ``with`` statement:

    >>> with Stopwatch() as sw:  #doctest: +SKIP
    ...     do_this()
    ...     do_that()
    ...
    >>> print(sw.elapsed)  #doctest: +SKIP
    1.234567

    ``timer`` is a zero-argument callable returning the current time;
    it defaults to ``timeit.default_timer``.
    """
    def __init__(self, timer=None):
        if timer is None:
            from timeit import default_timer as timer
        self.timer = timer
        self._start = self._end = self._elapsed = None

    def __enter__(self):
        self._start = self.timer()
        return self

    def __exit__(self, *args):
        self._end = self.timer()

    @property
    def elapsed(self):
        # Only meaningful once the ``with`` block has been exited.
        return self._end - self._start


# Column layout matches the "%8d" / "%8.3f" fields printed by run_test.
HEADER = """\
       N     sort  select7 select23 select47 select97   select  select2 select2b np.median
-------- -------- -------- -------- -------- -------- -------- -------- -------- ---------"""


def run_test(data, single=True):
    """Run a timing test against all select* functions, and print the
    results. Smaller times are faster, hence better.

    Argument ``data`` is the list of data points. If ``single`` is true,
    each function is called once. If ``single`` is false, each function
    is called three times, using different ranks, and the total time
    taken is averaged.

    These timing tests assume that each test takes an appreciable amount
    of time (e.g. seconds rather than milliseconds), and do not take
    heroic measures to reduce the measurement overhead. Consequently, for
    very small and very low timings, the results shown may be inaccurate.
    """
    size = len(data)
    if single:
        ranks = (size//2,)
    else:
        # Pick three semi-arbitrary ranks.
        ranks = (size//2, size//3, 4*size//5)
    # One list per rank, holding the results of calling select functions.
    results = tuple([] for _ in ranks)

    # Every contender is called as func(list_copy, rank), in the same
    # order as the columns of HEADER.
    contenders = [select_by_sorting]
    contenders.extend(
        # k=k binds the current group size; a plain closure would see
        # only the last value of the loop variable.
        (lambda a, r, k=k: selectk(k, a, r)) for k in (7, 23, 47, 97)
    )
    contenders.extend([
        select,
        lambda a, r: select2(a, 0, size, r),
        select2b,
        lambda a, r: np.partition(a, r)[r],
    ])

    print("%8d" % size, end=' ', flush=True)
    last = len(contenders) - 1
    for pos, func in enumerate(contenders):
        a = data[:]
        with Stopwatch() as sw:
            for r, L in zip(ranks, results):
                L.append(func(a, r))
        # Average over the number of calls actually made (1 or 3).
        print("%8.3f" % (sw.elapsed/len(ranks)),
              end='\n' if pos == last else ' ', flush=True)

    # Verify that all the functions gave the same result.
    for rank, L in zip(ranks, results):
        if any(x != L[0] for x in L):
            print("test for rank %d failed" % rank, L)


def gen_random_perm(size):
    """ Generates a random permutation of range(size) as a list """
    data = list(range(size))
    random.shuffle(data)
    return data


def drive(single=True):
    """Run the whole benchmark suite and print a table of timings.

    ``single`` is passed through to ``run_test`` and selects between one
    call per function and the average of three calls.
    """
    with Stopwatch() as sw:
        if single:
            print("== Single call mode ==")
        else:
            print("== Average of three calls mode ==")
        print(HEADER)
        for i in range(4, 7):
            size = 10**i
            run_test(gen_random_perm(size//2), single)
            run_test(gen_random_perm(size), single)
        for i in range(2, 12):
            size = i*10**6
            run_test(gen_random_perm(size), single)
        print('Non-randoms')
        run_test([0]*10**7, single)
        run_test(list(range(10**7)), single)
        run_test(list(reversed(range(10**7))), single)
    print("Total elapsed time: %.2f minutes" % (sw.elapsed/60))
    print()