from __future__ import print_function
import numpy as np
# Tim Peters' select by median-of-medians-of-k
# https://mail.python.org/pipermail/python-list/2002-May/132170.html
#
# with modifications by Steven D'Aprano.
def select_by_sorting(a, rank):
    """Sort ``a`` in place and return the item at zero-based position ``rank``.

    This is the reference selector the faster ones are compared against.
    Rank 0 is the smallest item and ``len(a) - 1`` the largest.  The caller's
    list is left sorted.
    """
    size = len(a)
    assert 0 <= rank < size
    a.sort()
    ordered = a
    return ordered[rank]
def selectk(k, a, rank):
    """Return rank'th item of list a using median-of-median-of-k.

    ``k`` is the group size. The pivot is the (high) median of the medians
    of each complete group of ``k`` consecutive items, and it is found by a
    recursive call. Lists of at most ``k`` items are simply sorted.

    May modify a in place: the partition step swaps items of ``a``. The
    recursive calls work on slices of ``a``, i.e. on copies.
    """
    n = len(a)
    assert 0 <= rank < n
    if n <= k:
        # Base case: small enough to sort a copy directly.
        return sorted(a)[rank]
    # Find the high median-of-median-of-k's.
    # Only complete groups of k are used; a trailing partial group is skipped.
    medians = [sorted(a[i:i+k])[k//2] for i in range(0, n-k+1, k)]
    median = selectk(k, medians, len(medians)//2)
    # Partition around the median.
    # a[:i] <= median <= a[j+1:]
    # Two indices scan towards each other and swap out-of-place items. The
    # pivot is itself an item of a (a median of one of its groups), so the
    # first inner scans are guaranteed to stop.
    i, j = 0, n-1
    while i <= j:
        while a[i] < median:
            i += 1
        while a[j] > median:
            j -= 1
        if i <= j:
            a[i], a[j] = a[j], a[i]
            i += 1
            j -= 1
    if rank < i:
        # The wanted item lies in the left part; recurse on a copy of it.
        return selectk(k, a[:i], rank)
    else:
        # Otherwise it lies in the right part, so shift rank by the left size.
        return selectk(k, a[i:], rank-i)
# Thomas Dybdahl Ahle's select and select2.
# http://pastebin.com/30x0j39a
def select(xs, k):
    """Return sorted(xs)[k] using expected linear time.

    This is quickselect with a median-of-three pivot and a three-way
    partition. It works in place, so ``xs`` is reordered. The partition uses
    ``x + 1`` as the upper bound of the "equal to pivot" band, so the items
    are assumed to be integers.

    Raises AssertionError unless 0 <= k < len(xs). This is the same
    precondition that select_by_sorting and selectk check. Without it an
    out-of-range ``k`` silently returns a wrong item, and an empty list
    fails with an unrelated IndexError.
    """
    assert 0 <= k < len(xs)
    # Inv: xs[:a] <= res < xs[b:]
    a, b = 0, len(xs)
    while a + 1 != b:
        x = choose_pivot(xs, a, b)
        # Partition xs[a:b] such that xs[a:i] < x, xs[i:j] == x, xs[j:b] > x
        # (for integers, "< x + 1" is the same as "<= x").
        i, j = partition3(xs, a, b, x, x+1)
        if k < i:
            b = i
        # If i <= res < j, res must be equal to the pivot
        elif k < j:
            a, b = i, i+1
        else:
            a = j
    return xs[a]
def choose_pivot(xs, a, b):
    """Pick one pivot for the range xs[a:b].

    The pivot is the median of the range's first, middle and last items.
    """
    first = xs[a]
    middle = xs[(a + b) // 2]
    last = xs[b - 1]
    return median_of_three(first, middle, last)
def median_of_three(x, y, z):
    """Return the median of three values.

    Sorting the three values works for any mutually comparable values
    (ints, floats, strings, ...) and always returns one of the inputs.
    The previous arithmetic form, ``x + y + z - (min + max)``, needed
    numbers, could lose precision with floats (it gave 0.0 for
    (1e16, 1.0, 0.0)), and could overflow fixed-width integers.
    """
    return sorted((x, y, z))[1]
def select2(xs, a, b, k):
    """Dual pivot, ala Java 7 quick sort, version of select.

    Return the item that would sit at index ``k`` if ``xs[a:b]`` were sorted
    in place, i.e. ``sorted(xs[a:b])[k - a]``. So
    ``select2(xs, 0, len(xs), k)`` is ``sorted(xs)[k]``. ``xs[a:b]`` is
    reordered in place. Items are assumed to be integers, because the
    pivots are only guaranteed distinct by using ``x + 1`` (see
    choose_pivots).

    Raises AssertionError unless a <= k < b. Without this check, a ``k``
    outside the range silently gives a wrong answer, and an empty range
    reads unrelated items. The check also holds on every recursive call
    below, because each narrowed range still contains ``k``.
    """
    assert a <= k < b
    if a + 1 == b:
        return xs[a]
    x, y = choose_pivots(xs, a, b)
    # Afterwards: xs[a:i] < x <= xs[i:j] < y <= xs[j:b]
    i, j = partition3(xs, a, b, x, y)
    if k < i:
        return select2(xs, a, i, k)
    if k >= j:
        return select2(xs, j, b, k)
    # If all the values in the middle segment are equal,
    # we have to return early to ensure termination.
    if x + 1 == y:
        return x
    return select2(xs, i, j, k)
def select2b(xs, k, lim=100):
    """Dual pivot, ala Java 7 quick sort, version of select (iterative).

    Return sorted(xs)[k], reordering ``xs`` in place. This is an iterative
    form of select2. After ``lim`` partitioning rounds without converging,
    it falls back to sorting the remaining window. Items are assumed to be
    integers, because the pivots are only guaranteed distinct by using
    ``x + 1`` (see choose_pivots).

    Raises AssertionError unless 0 <= k < len(xs), matching the other
    selectors.
    """
    assert 0 <= k < len(xs)
    depth = 0
    # Inv: the answer is sorted(xs[a:b])[k - a].
    a, b = 0, len(xs)
    while a + 1 != b and depth != lim:
        x, y = choose_pivots(xs, a, b)
        i, j = partition3(xs, a, b, x, y)
        if k < i:
            b = i
        elif k >= j:
            a = j
        else:
            # If all the values in the middle segment are equal,
            # we have to return early to ensure termination.
            if x + 1 == y:
                return x
            a, b = i, j
        depth += 1
    if a + 1 == b:
        return xs[a]
    # Depth limit reached. Sort what is left of the window. The slice starts
    # at index a, so the rank must be made relative to it. (This call used
    # to omit the rank argument and raised TypeError.)
    return select_by_sorting(xs[a:b], k - a)
def choose_pivots(xs, a, b):
    """Pick two pivots for xs[a:b], in the spirit of Java 7's dual-pivot sort.

    The candidates are the items one third of the way in from each end of
    the range, and they are returned in ascending order. If the two
    candidates are equal, the second pivot becomes ``x + 1``, so the result
    always satisfies ``x < y`` (this relies on integer items).
    """
    offset = (b - a) // 3
    p, q = xs[a + offset], xs[b - 1 - offset]
    if p < q:
        return p, q
    elif p == q:
        return p, p + 1
    else:
        return q, p
def partition3(xs, a, b, x, y):
    """Partition xs[a:b] in place into three bands around the bounds x < y.

    Returns ``(i, j)`` such that afterwards
    ``xs[a:i] < x <= xs[i:j] < y <= xs[j:b]``.
    Items outside ``xs[a:b]`` are not touched.
    """
    assert 0 <= a < b <= len(xs) and x < y
    lo = mid = a  # xs[a:lo] < x and x <= xs[lo:mid] < y
    hi = b        # y <= xs[hi:b]; xs[mid:hi] has not been examined yet
    while mid != hi:
        item = xs[mid]
        if item < x:
            # Swap the small item down to the end of the "< x" band.
            xs[lo], xs[mid] = item, xs[lo]
            lo += 1
            mid += 1
        elif item < y:
            mid += 1
        else:
            # Swap the large item to the front of the ">= y" band. The
            # item swapped in is unexamined, so mid does not advance.
            hi -= 1
            xs[mid], xs[hi] = xs[hi], item
    return lo, mid
# === Test code ===
import random
def validate():
    """Confirm that all select functions give the correct result.

    Each rank 0..499 is visited in random order. For each rank, every
    selector is run on a freshly shuffled permutation of 1000..1499 and
    must return 1000 + rank. This includes select2b, which was previously
    not checked.
    """
    ranks = list(range(500))
    data = [1000 + r for r in ranks]
    random.shuffle(ranks)  # Test in random order.
    for rank in ranks:
        expected = 1000 + rank
        random.shuffle(data)
        assert select_by_sorting(data, rank) == expected
        for k in (7, 23, 47, 97):
            random.shuffle(data)
            assert selectk(k, data, rank) == expected
        random.shuffle(data)
        assert select(data, rank) == expected
        random.shuffle(data)
        assert select2(data, 0, len(data), rank) == expected
        random.shuffle(data)
        assert select2b(data, rank) == expected
# Helper class for timings. Since we mostly care about large data sets,
# the times will be moderately large (e.g. in excess of a second). In
# that case, we don't bother with timeit, which is designed for timing
# small snippets with millisecond timings.
class Stopwatch:
    """Time hefty or long-running block of code using a ``with`` statement:

    >>> with Stopwatch() as sw: #doctest: +SKIP
    ...     do_this()
    ...     do_that()
    ...
    >>> print(sw.elapsed) #doctest: +SKIP
    1.234567

    ``timer`` is a zero-argument callable that returns the current time. It
    defaults to ``timeit.default_timer``, which reports seconds.
    ``elapsed`` may also be read while the block is still running, in which
    case it reports the time so far. Exceptions raised inside the block are
    not suppressed.
    """

    def __init__(self, timer=None):
        if timer is None:
            from timeit import default_timer as timer
        self.timer = timer
        self._start = self._end = None

    def __enter__(self):
        self._start = self.timer()
        # Clear any previous stop time so the same Stopwatch can be reused.
        self._end = None
        return self

    def __exit__(self, *args):
        self._end = self.timer()
        # Returning None (falsy) lets any exception from the block propagate.

    @property
    def elapsed(self):
        """Time between entering and leaving the block.

        If the block has not finished yet, this is the time until now.
        """
        end = self.timer() if self._end is None else self._end
        return end - self._start
# Column headings for the timing table. drive() prints this once; each
# run_test() call then prints one row with a value under each heading.
HEADER = """\
N sort select7 select23 select47 select97 select select2 select2b np.median
-------- -------- -------- -------- -------- -------- -------- -------- -------- ---------"""
def _time_calls(call, data, ranks, results, last=False):
    """Time ``call(a, r)`` for each rank and print the average as one column.

    One fresh copy ``a`` of ``data`` is made, and ``call(a, r)`` is run for
    every rank ``r``. Each answer is appended to the results list that
    matches its rank. The average time per call is printed as one table
    column; ``last`` ends the row.
    """
    a = data[:]
    with Stopwatch() as sw:
        for r, L in zip(ranks, results):
            L.append(call(a, r))
    # Average over the number of calls actually made (1 or 3), not a fixed 3.
    print("%8.3f" % (sw.elapsed/len(ranks)), end='\n' if last else ' ',
          flush=True)


def run_test(data, single=True):
    """Run a timing test against all select* functions, and print the
    results. Smaller times are faster, hence better.

    Argument ``data`` is the list of data points. It is copied before each
    function is timed, so it is not modified.
    If ``single`` is true, each function is called once. If ``single`` is
    false, each function is called three times, using different ranks,
    and the total time taken is averaged.

    These timing tests assume that each test takes an appreciable amount of
    time (e.g. seconds rather than milliseconds), and do not take heroic
    measures to reduce the measurement overhead. Consequently, for very
    small data sets (and hence very low timings), the results shown may be
    inaccurate.

    After timing, the answers of all functions are compared rank by rank,
    and any disagreement is printed.
    """
    size = len(data)
    if single:
        ranks = (size//2,)
    else:
        # Pick three semi-arbitrary ranks.
        ranks = (size//2, size//3, 4*size//5)
    # One list per rank, collecting every function's answer for that rank.
    results = tuple([] for _ in ranks)
    print("%8d" % size, end=' ', flush=True)
    _time_calls(select_by_sorting, data, ranks, results)
    for k in (7, 23, 47, 97):
        _time_calls(lambda a, r, k=k: selectk(k, a, r), data, ranks, results)
    _time_calls(select, data, ranks, results)
    _time_calls(lambda a, r: select2(a, 0, size, r), data, ranks, results)
    _time_calls(select2b, data, ranks, results)
    _time_calls(lambda a, r: np.partition(a, r)[r], data, ranks, results,
                last=True)
    # Verify that all the functions gave the same result.
    for rank, L in zip(ranks, results):
        if any(x != L[0] for x in L):
            print("test for rank %d failed" % rank, L)
def gen_random_perm(size):
    """Return a new list holding a random permutation of ``range(size)``."""
    data = list(range(size))
    random.shuffle(data)
    return data
def drive(single=True):
    """Print the full timing table, then the total wall time in minutes.

    The rows are run in this order:
    - random permutations of 5e3, 1e4, 5e4, 1e5, 5e5 and 1e6 items;
    - random permutations of 2e6 .. 11e6 items;
    - three non-random lists of 1e7 items (all zeros, ascending, descending).

    ``single`` is passed through to run_test.
    """
    if single:
        banner = "== Single call mode =="
    else:
        banner = "== Average of three calls mode =="
    big = 10**7
    # Build each non-random list only when its row runs, one at a time.
    nonrandom_builders = (
        lambda: [0]*big,
        lambda: list(range(big)),
        lambda: list(reversed(range(big))),
    )
    with Stopwatch() as sw:
        print(banner)
        print(HEADER)
        for exponent in range(4, 7):
            power = 10**exponent
            for size in (power//2, power):
                run_test(gen_random_perm(size), single)
        for millions in range(2, 12):
            run_test(gen_random_perm(millions*10**6), single)
        print('Non-randoms')
        for build in nonrandom_builders:
            run_test(build(), single)
    print("Total elapsed time: %.2f minutes" % (sw.elapsed/60))
    print()