def merge_at(s, i):
    """Fold entry ``s[i+1]`` into ``s[i]`` in place and return the new ``s[i]``.

    ``s`` is mutated: ``s[i]`` becomes ``s[i] + s[i+1]`` and the entry at
    ``i+1`` is removed.  Callers in this file add the returned value to a
    running merge cost.
    """
    right = i + 1
    s[i] += s[right]
    del s[right]
    return s[i]
def merge_force_collapse(s):
    """Merge the stack ``s`` down to at most one entry; return the total cost.

    On each step the candidate is the pair just below the top of the stack;
    if the entry beneath that pair is smaller than the top entry, the merge
    is shifted one position down instead.  ``s`` is mutated in place.
    """
    total = 0
    while len(s) >= 2:
        idx = len(s) - 2
        # Shift the merge point down when the lower neighbour is the smaller one.
        if idx >= 1 and s[idx - 1] < s[idx + 1]:
            idx -= 1
        total += merge_at(s, idx)
    return total
def sort(runs, merge_collapse):
    """Simulate a stack-based merge strategy over the run lengths in ``runs``.

    Each run length is pushed on a stack, after which ``merge_collapse(stack)``
    is called; it may merge entries in place and must return the cost of the
    merges it performed.  Whatever remains at the end is merged by
    ``merge_force_collapse``.

    Returns ``(cost, maxstack)``: the accumulated cost and the largest stack
    size observed immediately after a push.
    """
    pending = []
    deepest = 0
    total = 0
    for run in runs:
        pending.append(run)
        if len(pending) > deepest:
            deepest = len(pending)
        total += merge_collapse(pending)
    total += merge_force_collapse(pending)
    return total, deepest
def timsort(s):
    """Apply the timsort collapse rule to the stack ``s``; return the merge cost.

    With ``n`` the index of the second-from-top entry, keep merging while
    either ``s[n-1] <= s[n] + s[n+1]`` or ``s[n-2] <= s[n-1] + s[n]`` holds
    (where those entries exist), or while ``s[n] <= s[n+1]``.  ``s`` is
    mutated in place.
    """
    total = 0
    while len(s) >= 2:
        n = len(s) - 2
        sum_rule_violated = (
            (n > 0 and s[n-1] <= s[n] + s[n+1])
            or (n > 1 and s[n-2] <= s[n-1] + s[n])
        )
        if sum_rule_violated:
            # n > 0 here, so s[n-1] exists; merge toward the smaller neighbour.
            where = n - 1 if s[n-1] < s[n+1] else n
            total += merge_at(s, where)
        elif s[n] <= s[n+1]:
            total += merge_at(s, n)
        else:
            # No rule is violated; stop collapsing.
            return total
    return total
def twomerge(s):
    """Apply the "2-merge" collapse rule to the stack ``s``; return the cost.

    Merging continues while the second-from-top entry is less than twice the
    top entry.  When a merge happens it is shifted one position down if the
    entry below the pair is smaller than the top entry.  ``s`` is mutated in
    place.  On exit the top entries are asserted to shrink by at least a
    factor of two going up the stack.
    """
    total = 0
    while len(s) >= 2:
        k = len(s) - 2
        if s[k] >= 2 * s[k+1]:
            # Factor-of-two rule already holds at the top; nothing to do.
            break
        if k > 0 and s[k-1] < s[k+1]:
            k -= 1
        total += merge_at(s, k)
    # Sanity checks on the resulting stack shape.
    if len(s) > 1:
        assert s[-2] >= 2 * s[-1]
    if len(s) > 2:
        assert s[-3] >= 2 * s[-2]
    return total
# Minimal cost across all possible ways of merging runs.
# Takes time cubic in len(runs).
def ideal(runs):
    """Return the minimal total cost over all ways of merging adjacent runs.

    ``runs`` is any iterable of run lengths.  The cost of one merge is the
    combined length of the two pieces merged.  Uses dynamic programming over
    (start, width) slices, so the time taken is cubic in ``len(runs)``.

    Returns 0 when ``runs`` is empty or has a single run (nothing to merge).
    """
    runs = list(runs)
    n = len(runs)
    if n == 0:
        # No runs, no merges; without this guard runs[0] below would raise.
        return 0
    # c[i, w] is minimal cost of merging slice runs[i : i+w].
    c = {(i, 1): 0 for i in range(n)}
    prefixsum = runs[0]
    for width in range(2, n + 1):
        # prefixsum == sum(runs[:width]); `total` then slides across the
        # array so that total == sum(runs[start : start+width]).
        prefixsum += runs[width - 1]
        total = prefixsum
        for start in range(n - width + 1):
            # The final merge of the slice always costs `total`; choose the
            # split point that minimizes the cost of building the two halves.
            c[start, width] = total + min(
                c[start, w] + c[start + w, width - w]
                for w in range(1, width)
            )
            if start + width < n:
                total += runs[start + width] - runs[start]
    return c[0, n]
def greedy(runs):
    """Return the cost of always merging the adjacent pair with the smallest sum.

    Ties are broken in favour of the leftmost pair.  ``runs`` itself is not
    modified; a copy is merged down to a single entry.
    """
    pending = list(runs)
    total = 0
    while len(pending) > 1:
        pair_sums = [a + b for a, b in zip(pending, pending[1:])]
        # list.index returns the first (leftmost) occurrence of the minimum.
        best = pair_sums.index(min(pair_sums))
        total += merge_at(pending, best)
    return total
# Bad sequence of run lengths for timsort, summing to `n`.
def rtim(n):
    """Return a list of run lengths summing to ``n``, built recursively.

    For ``n <= 3`` the result is ``[n]``.  Otherwise it is the sequence for
    ``n >> 1``, followed by the sequence for ``(n >> 1) - 1``, followed by a
    final run of length 1 (even ``n``) or 2 (odd ``n``).
    """
    if n > 3:
        half, odd = n >> 1, n & 1
        return [*rtim(half), *rtim(half - 1), odd + 1]
    return [n]
def power(s1, n1, n2, n):
    """Return the "power" of the boundary between two adjacent runs.

    The first run starts at index ``s1`` and has length ``n1``; the second
    run follows it and has length ``n2``; ``n`` is the total array length.
    The result is the 1-based position of the first bit in which the binary
    fractions ``a/n`` and ``b/n`` differ, where ``a`` and ``b`` are the
    midpoints of the two runs.  Uses integer arithmetic only.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # Midpoints:
    #     a = s1 + n1/2
    #     b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    # Work with the doubled values so everything stays an integer.
    twice_a = 2*s1 + n1
    twice_b = twice_a + n1 + n2
    # Array length has d bits.  Max power is d:
    #     b/n - a/n = (b-a)/n = (n1 + n2)/2/n >= 2/2/n = 1/n > 1/2**d
    # So at worst b/n and a/n differ in bit 1/2**d.
    # a and b have <= d+1 bits.  Shift left by d-1 and divide by 2n =
    # shift left by d-2 and divide by n.  Result is d - bit length of
    # xor.  After the shift, the numerator has at most d+1 + d-2 = 2*d-1
    # bits.  Any value of d >= n.bit_length() can be used.
    d = n.bit_length()  # or larger; smaller can fail
    shift = d - 2
    qa = (twice_a << shift) // n
    qb = (twice_b << shift) // n
    return d - (qa ^ qb).bit_length()
# Like power but move n up to a power of 2.
def powercheat(s1, n1, n2, n):
    """Variant of the power computation that replaces the division by a shift.

    Takes the same arguments as ``power``: run start ``s1``, run lengths
    ``n1`` and ``n2``, and total length ``n``.  The doubled midpoints are
    shifted right by two bits instead of being scaled and divided by ``n``.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    d = n.bit_length()  # or larger; smaller can fail
    # 2*a = 2*s1 + n1 and 2*b = 2*a + n1 + n2, each then shifted right by 2.
    qa = (2*s1 + n1) >> 2
    qb = (2*s1 + 2*n1 + n2) >> 2
    return d - (qa ^ qb).bit_length()
def powerloop(s1, n1, n2, n):
    """Compute the run-boundary power one bit at a time.

    Takes the same arguments as ``power``.  Starting from the doubled
    midpoints ``2*a`` and ``2*b``, each iteration examines one binary digit
    of ``a/n`` and ``b/n``: if both are >= n the digit is 1 for both and is
    stripped; if only ``b`` is >= n the digits differ and the current level
    is the answer; otherwise both digits are 0.  The values are then doubled
    to expose the next digit.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # a = s1 + n1/2
    # b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    a = 2*s1 + n1  # 2*a
    b = a + n1 + n2  # 2*b
    level = 1
    while True:
        if a >= n:
            assert b >= n
            a -= n
            b -= n
        elif b >= n:
            return level
        assert a < b < n
        a *= 2
        b *= 2
        level += 1
from math import frexp
# Scale factor used to turn a frexp() mantissa into an integer.
T53 = 2.0 ** 53

# Using frexp works fine for small n, but flakes out quickly as n grows.
# To get it to flake out even faster, instead of dividing by n, multiply
# by 1.0/n.
def power_frexp(s1, n1, n2, n):
    """Floating-point variant of the power computation, based on ``frexp``.

    Takes the same arguments as ``power``.  The midpoints divided by ``n``
    are split into mantissa and exponent.  If the exponents differ, the
    answer follows from the larger exponent alone; otherwise the mantissas
    are scaled by ``T53``, truncated to integers and XORed to find the first
    differing bit.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # a = s1 + n1/2
    # b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    mid_a = s1 + n1 * 0.5
    mid_b = mid_a + (n1 + n2) * 0.5
    (m1, e1), (m2, e2) = frexp(mid_a / n), frexp(mid_b / n)
    if e1 == e2:
        diff = int(m1 * T53) ^ int(m2 * T53)
        assert diff
        return 54 - diff.bit_length() - e1
    return 1 - max(e1, e2)
if 0:
    # Disabled experiment: exhaustively compare powerloop() against
    # powercheat() for every array length N in 2..10000 and every legal
    # (s1, n1, n2) combination, stopping at the first disagreement.
    for N in range(2, 10001):
        print(N)
        for s1 in range(N-1):
            # 2 <= sum of lengths <= N - s1
            for lensum in range(2, N - s1 + 1):
                for n1 in range(1, lensum):
                    n2 = lensum - n1
                    pold = powerloop(s1, n1, n2, N)
                    pnew = powercheat(s1, n1, n2, N)
                    if pold == pnew:
                        continue
                    print("OUCH", N, pold, pnew, s1, n1, n2)
                    assert False
if 0:
    # Disabled experiment: compare powerloop() against powercheat() forever
    # on random array lengths N with 2**60 <= N < 2**61, trying boundary and
    # random choices for the run lengths; stops at the first disagreement.
    from random import randrange
    count = 0
    LO = 1 << 60
    HI = LO << 1
    while True:
        count += 1
        N = randrange(LO, HI)
        s1 = randrange(N-1)
        for lensum in 2, N - s1, randrange(3, N - s1):
            for n1 in 1, lensum - 1, randrange(1, lensum):
                n2 = lensum - n1
                pold = powerloop(s1, n1, n2, N)
                pnew = powercheat(s1, n1, n2, N)
                if pold == pnew:
                    continue
                print("OUCH", N, pold, pnew, s1, n1, n2)
                assert False
        # Progress indicator.
        if count % 100_000 == 0:
            print(count)
def powersort(runs):
    """Simulate the powersort merge strategy on a sequence of run lengths.

    ``runs`` is any iterable of run lengths.  For each boundary between
    adjacent runs the boundary's power is computed with ``power()``; stacked
    runs whose recorded power is greater than the new boundary's power are
    merged into the current run before the current run is pushed.  When the
    input is exhausted, everything left on the stack is merged.

    Returns ``(cost, maxstack)``: the sum of the lengths produced by every
    merge, and the largest stack size observed.  An empty ``runs`` returns
    ``(0, 0)``, the same result ``sort()`` gives for empty input.
    """
    runs = list(runs)
    if not runs:
        # Nothing to merge; without this guard runs[0] below raises IndexError.
        return 0, 0
    cost = 0
    maxstack = 0
    s = []  # stack of (start, length, power) triples
    n = sum(runs)
    # (s1, n1) is the current run; (s2, n2) is the run that follows it.
    s1, n1 = 0, runs[0]
    for n2 in runs[1:]:
        s2 = s1 + n1
        p = power(s1, n1, n2, n)
        #p = powercheat(s1, n1, n2, n)
        # Merge the current run with stacked runs of higher power.
        while s and s[-1][-1] > p:
            s0, n0, _ = s.pop()
            assert s0 + n0 == s1
            n1 += n0
            cost += n1
            s1 = s0
        assert s1 + n1 == s2
        if s:
            assert s[-1][-1] < p  # never equal!
        s.append((s1, n1, p))
        maxstack = max(maxstack, len(s))
        s1, n1 = s2, n2
    # Input exhausted: merge everything that is still pending.
    while s:
        s0, n0, _ = s.pop()
        assert s0 + n0 == s1
        n1 += n0
        cost += n1
        s1 = s0
    assert (s1, n1) == (0, n), (s1, n1, 0, n)
    return cost, maxstack
def show(runs):
    """Print each run length, annotated with the power of the boundary after it.

    Every run except the last is printed as ``"<length> <<power>>"`` using
    ``power()`` for the boundary with the following run; the last run is
    printed as its bare length.
    """
    runs = list(runs)
    n = sum(runs)
    s1 = 0
    n1 = runs[0]
    for n2 in runs[1:]:
        p = power(s1, n1, n2, n)
        print(f"{n1} <{p}>")
        # Advance to the next run: it starts where the current one ends.
        s1 += n1
        n1 = n2
    print(n1)
def midpoints(runs):
    """Print each run length with its midpoint as a fraction of the total length."""
    runs = list(runs)
    total = sum(runs)
    start = 0
    for length in runs:
        print(length, (start + length/2) / total)
        start += length
if 1:
    def one(tag, r):
        """Print `tag`, then the results of all three strategies on runs `r`."""
        print(tag)
        r = list(r)
        print(" timsort", sort(r, timsort))
        print(" twomerge", sort(r, twomerge))
        print("powersort", powersort(r))
        print()

    from random import random, randrange

    # Hand-picked run-length patterns.
    one("all the same", [32] * 1000)  # identical stats
    one("ascending", range(1, 2000))  # timsort a little better
    one("descending", reversed(range(1, 2000)))  # twomerge significantly better
    one("bad timsort case", rtim(100000))
    n = 1_000_000
    one("bad powersort case", (n-1, 1, 1, n))
    one("another bad timsort", (190000, 180000, 10000, 10000))
    one("bad twomerge case", (190000, 60000, 40000, 10000))

    # "Random" run-length distributions are largely irrelevant, since on
    # randomly ordered input the actual sort is most likely to _force_
    # (via local binary insertion sorts) all runs to length `minrun`.
    # Nevertheless ... there's no clear overall winner in this
    # particular made-up distribution.  On some runs timsort "wins" in
    # the end, on others twomerge, but the total costs are typically
    # within 2% of each other.
    # Later:  powersort almost always dominates both.
    print("randomized trials")
    totals = [0, 0, 0]  # accumulated cost: timsort, twomerge, powersort
    for trial in range(20):
        # 80% short runs, 20% long runs.
        runs = []
        for i in range(10000):
            switch = random()
            x = randrange(1, 100) if switch < 0.80 else randrange(1000, 10000)
            runs.append(x)

        for i, which in enumerate([timsort, twomerge]):
            cost, depth = sort(runs, which)
            print(f"{which.__name__:8s}", cost, depth)
            totals[i] += cost
        #cost = ideal(runs)
        #print(" ideal", cost)
        #totals[2] += cost
        #cost = greedy(runs)
        #print(" greedy", cost)
        #totals[3] += cost
        cost, depth = powersort(runs)
        print("power ", cost, depth)
        totals[2] += cost
        print("sofar", totals, "timsort excess",
              f"{(totals[0] - totals[1]) / totals[1]:.2%}")
if 0:
    def p(xs):
        """Return a generator over the permutations of list `xs`.

        The same list object is yielded each time, rearranged in place, so
        callers must copy it if they need to keep a permutation.
        """
        n = len(xs)

        def inner(i):
            if i == n:
                # Positions 0..n-1 are all fixed: one complete permutation.
                yield xs
            else:
                orig = xs[i]
                for j in range(i, n):
                    # Put xs[j] at position i, permute the tail, then put
                    # xs[j] back where it came from.
                    xs[i], xs[j] = xs[j], orig
                    yield from inner(i+1)
                    xs[j] = xs[i]
                xs[i] = orig

        return inner(0)
if 0:
    # Disabled experiment: over all 4-run inputs with lengths in 1..19,
    # report each new worst-case ratio of powersort's cost to ideal() cost.
    from itertools import product
    worst = 0.0
    first = 2
    for t in product(range(1, 20), repeat=4):
        # Progress indicator: print once each time the leading length advances.
        if t[0] == first:
            print(t)
            first += 1
        g = ideal(t)
        b, ignore = powersort(t)
        #b, ignore = sort(t, timsort)
        #b, ignore = sort(t, twomerge)
        assert g <= b
        if g == b:
            continue
        ratio = b / g
        if ratio > worst:
            print(t, b, g,
                  sort(t, timsort), sort(t, twomerge), ratio)
            worst = ratio
    print("done")
# n-1, 1, 1, n # sum = 2n + 1
# best = 2 + n+1 + 2n+1 = 3n + 4
# powersort = 4n + 2
# a = (n-1)/2 / (2*n + 1) = (n-1)/(4n+2)
# b = (n - 1 + 1/2) / (2*n + 1) = (2n-1)/(4n+2)
# so (0, n-1, 1) has power 2
# then (n-1, 1, 1)
# a = n-1 + 1/2 = n - 1/2; (2n-1)/(4n+2)
# b = n + 1/2; (2n+1)/(4n+2)
# so has power 1, and n-1 is merged with 1 first