def merge_at(s, i):
    """Merge run lengths s[i] and s[i+1] in place.

    The stack `s` holds run lengths only.  Entry i becomes the sum of the
    two lengths, entry i+1 is removed, and the merged length is returned;
    callers use it as the cost of the merge.
    """
    s[i] += s[i+1]
    del s[i+1]
    return s[i]


def merge_force_collapse(s):
    """Merge everything left on stack `s` into one run; return the cost.

    Works from the top of the stack.  When the run below the top pair is
    shorter than the top run, the lower pair is merged instead.
    """
    cost = 0
    while len(s) > 1:
        n = len(s) - 2
        if n > 0 and s[n-1] < s[n+1]:
            n -= 1
        cost += merge_at(s, n)
    return cost


def sort(runs, merge_collapse):
    """Simulate a stack-based merge policy on a sequence of run lengths.

    Each run length is pushed on a stack, after which
    `merge_collapse(stack)` is called; it must merge in place and return
    the cost of the merges it performed.  Whatever is left at the end is
    merged by merge_force_collapse().

    Returns (total merge cost, maximum stack depth reached).
    """
    stack = []
    maxstack = 0
    cost = 0
    for x in runs:
        stack.append(x)
        maxstack = max(maxstack, len(stack))
        cost += merge_collapse(stack)
    cost += merge_force_collapse(stack)
    return cost, maxstack


def timsort(s):
    """Timsort's merge_collapse rule applied to run-length stack `s`.

    Merges in place until neither condition below holds, and returns the
    cost of the merges performed.
    """
    cost = 0
    while len(s) > 1:
        n = len(s) - 2
        if ((n > 0 and s[n-1] <= s[n] + s[n+1]) or
                (n > 1 and s[n-2] <= s[n-1] + s[n])):
            if s[n-1] < s[n+1]:
                n -= 1
            cost += merge_at(s, n)
        elif s[n] <= s[n+1]:
            cost += merge_at(s, n)
        else:
            break
    return cost


def twomerge(s):
    """2-merge style rule: merge while the top run is more than half the
    run below it.  Returns the cost of the merges performed.
    """
    cost = 0
    while len(s) > 1:
        n = len(s) - 2
        if s[n] < 2 * s[n+1]:
            if n > 0 and s[n-1] < s[n+1]:
                n -= 1
            cost += merge_at(s, n)
        else:
            break
    # Sanity checks: each of the top runs is at most half the run below it.
    if len(s) > 1:
        assert s[-2] >= 2 * s[-1]
    if len(s) > 2:
        assert s[-3] >= 2 * s[-2]
    return cost


# The Shivers sort is taken from
# "Adaptive Shivers Sort: An Alternative Sorting Algorithm"
# Vincent Jugé
# http://igm.univ-mlv.fr/~juge/papers/shivers_arxiv.pdf

# return floor(log2(i)) >= floor(log2(j))
def log2_ge(i, j):
    """True iff floor(log2(i)) >= floor(log2(j)), using integer ops only.

    If i < j, the xor clears the top bit of i exactly when both share the
    same top bit, making the xor smaller than i.
    """
    return i >= j or i > (i ^ j)


def log2_g(i, j):
    """True iff floor(log2(i)) > floor(log2(j))."""
    # Strictly greater is the negation of "j's log is >= i's log".
    return not log2_ge(j, i)


def shivers(s):
    """Adaptive Shivers sort merge rule on run-length stack `s`.

    Runs three successive merge loops comparing floor(log2) of the top
    run lengths.  Returns the cost of the merges performed.
    """
    cost = 0
    n = len(s)
    while n >= 3 and log2_ge(s[-1], s[-3]):
        cost += merge_at(s, n-3)
        n -= 1
    while n >= 3 and log2_ge(s[-2], s[-3]):
        cost += merge_at(s, n-3)
        n -= 1
    while n >= 2 and log2_ge(s[-1], s[-2]):
        cost += merge_at(s, n-2)
        n -= 1
    assert len(s) == n
    return cost


def shivers2(s):
    """Experimental variant of the Shivers merge rule.

    Never merges while fewer than 3 runs are on the stack.  Returns the
    cost of the merges performed.
    """
    cost = 0
    while len(s) > 1:
        n = len(s)
        if (n < 3):
            break
        elif (log2_g(s[-1], s[-2]) and s[-1] < s[-3]):
            cost += merge_at(s, n-2)
        elif (log2_g(s[-1], s[-3]) or log2_g(s[-1], s[-2]) or
              (n > 3 and log2_ge(s[-2], s[-3]))):
            if (n == 4 and (log2_ge(s[-2], s[-4]) or log2_g(s[-3], s[-4]))):
                cost += merge_at(s, n-4)
            else:
                cost += merge_at(s, n-3)
        else:
            break
    return cost


# Minimal cost across all possible ways of merging runs.
# Takes time cubic in len(runs).
def ideal(runs):
    """Return the minimal total cost of merging adjacent `runs` into one.

    The cost of a single merge is the sum of the merged run lengths.
    Uses dynamic programming over (start, width) slices.  An empty
    sequence costs 0, as does a single run.
    """
    runs = list(runs)
    n = len(runs)
    if not n:
        # Nothing to merge; avoids indexing runs[0] below.
        return 0
    # c[i, w] is minimal cost of merging slice runs[i : i+w].
    c = {(i, 1) : 0 for i in range(n)}
    prefixsum = runs[0]
    for width in range(2, n+1):
        # prefixsum == sum(runs[:width]); total slides across the array
        # as sum(runs[start : start+width]).
        prefixsum += runs[width - 1]
        total = prefixsum
        for start in range(n - width + 1):
            c[start, width] = total + min(
                c[start, w] + c[start + w, width - w]
                for w in range(1, width))
            if start + width < n:
                total += runs[start + width] - runs[start]
    return c[0, n]


def greedy(runs):
    """Cost of repeatedly merging the adjacent pair with the smallest sum.

    Ties are broken in favor of the leftmost pair (behavior of `min`).
    """
    s = list(runs)
    cost = 0
    while len(s) > 1:
        i = min(range(len(s) - 1), key=lambda i: s[i] + s[i+1])
        cost += merge_at(s, i)
    return cost


# Bad sequence of run lengths for timsort, summing to `n`.
def rtim(n):
    """Recursively build a list of run lengths whose sum is `n`."""
    if n <= 3:
        return [n]
    n2 = n >> 1
    return rtim(n2) + rtim(n2 - 1) + [(n & 1) + 1]


def power(s1, n1, n2, n):
    """Return the "power" of the boundary between two adjacent runs.

    The runs are [s1, s1+n1) and [s1+n1, s1+n1+n2) in an array of length
    `n`.  With a and b the run midpoints, the result is the 1-based index
    of the first bit in which the binary fractions a/n and b/n differ.
    Integer arithmetic only.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # a = s1 + n1/2
    # b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    a = 2*s1 + n1       # 2*a
    b = a + n1 + n2     # 2*b
    # Array length has d bits.  Max power is d:
    #     b/n - a/n = (b-a)/n = (n1 + n2)/2/n >= 2/2/n = 1/n > 1/2**d
    # So at worst b/n and a/n differ in bit 1/2**d.
    # a and b have <= d+1 bits. Shift left by d-1 and divide by 2n =
    # shift left by d-2 and divide by n.  Result is d - bit length of
    # xor.  After the shift, the numerator has at most d+1 + d-2 = 2*d-1
    # bits. Any value of d >= n.bit_length() can be used.
    d = n.bit_length()  # or larger; smaller can fail
    a = (a << (d-2)) // n
    b = (b << (d-2)) // n
    return d - (a ^ b).bit_length()


# Like power but move n up to a power of 2.
def powercheat(s1, n1, n2, n):
    """Variant of power() that treats the array length as 2**d.

    With d = n.bit_length(), the scale-and-divide step of power() reduces
    to a right shift by 2.  The disabled check scripts below compare its
    results against powerloop().
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    a = 2*s1 + n1       # 2*a
    b = a + n1 + n2     # 2*b
    d = n.bit_length()  # or larger; smaller can fail
    a >>= 2
    b >>= 2
    return d - (a ^ b).bit_length()


def powerloop(s1, n1, n2, n):
    """Compute the run-boundary power one bit at a time.

    Same arguments as power().  Returns the number of doubling steps
    needed until the two (doubled) midpoints fall on opposite sides of n.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # a = s1 + n1/2
    # b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    a = 2*s1 + n1       # 2*a
    b = a + n1 + n2     # 2*b
    L = 0
    while True:
        L += 1
        if a >= n:
            # Both in the upper half: drop the common leading 1 bit.
            assert b >= n
            a -= n
            b -= n
        elif b >= n:
            # a in the lower half, b in the upper half: bits differ here.
            break
        assert a < b < n
        a <<= 1
        b <<= 1
    return L


from math import frexp
T53 = float(2**53)

# Using frexp works fine for small n, but flakes out quickly as n grows.
# To get it to flake out even faster, instead of dividing by n, multiply
# by 1.0/n.
def power_frexp(s1, n1, n2, n):
    """Floating-point attempt at power(), built on math.frexp.

    Kept for experimentation only; see the comment above about its
    accuracy for large n.
    """
    assert s1 >= 0
    assert n1 >= 1 and n2 >= 1
    assert s1 + n1 + n2 <= n
    # a = s1 + n1/2
    # b = s1 + n1 + n2/2 = a + (n1 + n2)/2
    a = s1 + n1 * 0.5
    b = a + (n1 + n2) * 0.5
    m1, e1 = frexp(a / n)
    m2, e2 = frexp(b / n)
    if e1 != e2:
        return 1 - max(e1, e2)
    # Same exponent: compare the mantissas as 53-bit integers.
    m1 = int(m1 * T53)
    m2 = int(m2 * T53)
    x = m1 ^ m2
    assert x
    return 54 - x.bit_length() - e1


# Disabled: exhaustive comparison of powerloop() and powercheat().
if 0:
    for N in range(2, 10001):
        print(N)
        for s1 in range(N-1):
            # 2 <= sum of lengths <= N - s1
            for lensum in range(2, N - s1 + 1):
                for n1 in range(1, lensum):
                    n2 = lensum - n1
                    pold = powerloop(s1, n1, n2, N)
                    pnew = powercheat(s1, n1, n2, N)
                    if pold != pnew:
                        print("OUCH", N, pold, pnew, s1, n1, n2)
                        assert False

# Disabled: randomized comparison of powerloop() and powercheat() on
# 61-bit array lengths.  Runs forever.
if 0:
    from random import randrange
    count = 0
    LO = 1 << 60
    HI = LO << 1
    while True:
        count += 1
        N = randrange(LO, HI)
        s1 = randrange(N-1)
        for lensum in 2, N - s1, randrange(3, N - s1):
            for n1 in 1, lensum - 1, randrange(1, lensum):
                n2 = lensum - n1
                pold = powerloop(s1, n1, n2, N)
                pnew = powercheat(s1, n1, n2, N)
                if pold != pnew:
                    print("OUCH", N, pold, pnew, s1, n1, n2)
                    assert False
        if count % 100_000 == 0:
            print(count)


def powersort(runs):
    """Simulate the powersort merge policy on a sequence of run lengths.

    The stack holds (start, length, power) tuples, where power is that of
    the boundary between the entry and the run following it.  Each merge
    costs the length of the merged result.

    Returns (total merge cost, maximum stack depth).  An empty sequence
    yields (0, 0), matching what sort() returns for empty input.
    """
    cost = 0
    maxstack = 0
    s = []
    runs = list(runs)
    if not runs:
        # Nothing to merge; avoids indexing runs[0] below.
        return cost, maxstack
    n = sum(runs)
    s1, n1 = 0, runs[0]
    for i in range(1, len(runs)):
        s2, n2 = s1 + n1, runs[i]
        p = power(s1, n1, n2, n)
        #p = powercheat(s1, n1, n2, n)
        # Merge the current run with stacked runs of higher power.
        while s and s[-1][-1] > p:
            s0, n0, _ = s.pop()
            assert s0 + n0 == s1
            n1 += n0
            cost += n1
            s1 = s0
        assert s1 + n1 == s2
        if s:
            assert s[-1][-1] < p # never equal!
        s.append((s1, n1, p))
        maxstack = max(maxstack, len(s))
        s1, n1 = s2, n2
    # Final collapse: merge the last run with everything on the stack.
    while s:
        s0, n0, _ = s.pop()
        assert s0 + n0 == s1
        n1 += n0
        cost += n1
        s1 = s0
    assert (s1, n1) == (0, n), (s1, n1, 0, n)
    return cost, maxstack


def new_powersort(runs):
    """Experimental powersort variant.

    Differs from powersort() in two ways visible here: (1) when the
    second-from-top entry also has power > p and is shorter than the
    current run, the top two stack entries are merged with each other
    first; (2) an entry longer than a third of the total is pushed with
    power 1.

    Returns (total merge cost, maximum stack depth); (0, 0) for an empty
    sequence.
    """
    cost = 0
    maxstack = 0
    s = []
    runs = list(runs)
    if not runs:
        # Nothing to merge; avoids indexing runs[0] below.
        return cost, maxstack
    n = sum(runs)
    nthird = int(n // 3)
    s1, n1 = 0, runs[0]
    for i in range(1, len(runs)):
        s2, n2 = s1 + n1, runs[i]
        p = power(s1, n1, n2, n)
        #p = powercheat(s1, n1, n2, n)
        while s and s[-1][-1] > p:
            #if s[-1][1] > n2:
            #    break
            if len(s) > 1 and s[-2][-1] > p and s[-2][1] < n1:
                # Merge the two stacked entries, keeping the lower
                # entry's power.
                s0, n0, p0 = s[-2]
                n0 += s[-1][1]
                cost += n0
                assert s0 + n0 == s1
                s[-2] = s0, n0, p0
                del s[-1]
                continue
            s0, n0, _ = s.pop()
            assert s0 + n0 == s1
            n1 += n0
            cost += n1
            s1 = s0
        assert s1 + n1 == s2
        if 0:#s:
            assert s[-1][-1] < p # never equal!
        s.append((s1, n1, 1 if n1 > nthird else p))
        #s.append((s1, n1, p))
        maxstack = max(maxstack, len(s))
        s1, n1 = s2, n2
    while s:
        if len(s) > 1 and s[-2][1] < n1:
            s0, n0, _ = s[-2]
            n0 += s[-1][1]
            cost += n0
            assert s0 + n0 == s1
            s[-2] = s0, n0, _
            del s[-1]
            continue
        s0, n0, _ = s.pop()
        assert s0 + n0 == s1
        n1 += n0
        cost += n1
        s1 = s0
    assert (s1, n1) == (0, n), (s1, n1, 0, n)
    return cost, maxstack

# powersort = new_powersort


def pmerge_at(s, i):
    """Merge stack entries i and i+1, each a [start, length, power] list.

    Entry i keeps its start and power and gets the summed length; entry
    i+1 is removed.  Returns the merged length.
    """
    s[i][1] += s[i+1][1]
    del s[i+1]
    return s[i][1]


def pmerge_force_collapse(s):
    """Merge the top two entries repeatedly until one remains; return the
    total cost."""
    cost = 0
    while len(s) > 1:
        n = len(s) - 2
        cost += pmerge_at(s, n)
    return cost


def pnewrun(stack, s2, n2, n):
    """Push run (start s2, length n2) on `stack`, merging as needed first.

    The top entry's power is None until the run after it is known.  It is
    computed here, entries below it with higher power are merged away,
    and the new run is pushed with power None.  `n` is the total array
    length.  Returns the cost of the merges performed.
    """
    cost = 0
    if stack:
        s1, n1, p = stack[-1]
        assert s1 + n1 == s2
        assert p is None
        p = power(s1, n1, n2, n)
        while len(stack) > 1 and stack[-2][-1] > p:
            cost += pmerge_at(stack, len(stack) - 2)
        if len(stack) > 1:
            assert stack[-2][-1] < p
        # The surviving top entry carries the boundary's power.
        stack[-1][-1] = p
    stack.append([s2, n2, None])
    return cost


def powersort_squash(runs):
    """Powersort simulation structured like sort(): push each run through
    pnewrun(), then force-collapse what is left.

    Returns (total merge cost, maximum stack depth); (0, 0) for an empty
    sequence.
    """
    cost = 0
    maxstack = 0
    stack = []
    runs = list(runs)
    if not runs:
        # The stack would stay empty; avoids indexing stack[-1] below.
        return cost, maxstack
    n = sum(runs)
    s2 = 0
    for n2 in runs:
        cost += pnewrun(stack, s2, n2, n)
        maxstack = max(maxstack, len(stack))
        s2 += n2
    cost += pmerge_force_collapse(stack)
    s1, n1, p = stack[-1]
    assert (s1, n1) == (0, n), (s1, n1, 0, n)
    return cost, maxstack


def show(runs):
    """Print each run length followed by the power of the boundary after
    it; the last run length is printed alone.  Prints nothing for an
    empty sequence.
    """
    runs = list(runs)
    if not runs:
        return
    n = sum(runs)
    s1, n1 = 0, runs[0]
    for i in range(1, len(runs)):
        s2, n2 = s1 + n1, runs[i]
        p = power(s1, n1, n2, n)
        print(f"{n1} <{p}>")
        s1, n1 = s2, n2
    print(n1)


def midpoints(runs):
    """Print each run length with its midpoint as a fraction of the total
    length."""
    runs = list(runs)
    n = sum(runs)
    s = 0
    for r in runs:
        print(r, (s + r/2) / n)
        s += r


# Enabled: compare the merge policies on hand-picked cases and on
# randomized run-length distributions.
if 1:
    def one(tag, r):
        """Print (cost, max stack depth) of every policy for runs `r`."""
        print(tag)
        r = list(r)
        print(" timsort", sort(r, timsort))
        print(" twomerge", sort(r, twomerge))
        print(" shivers", sort(r, shivers))
        print(" shivers #2", sort(r, shivers2))
        print("powersort", powersort(r))
        print()

    from random import random, randrange
    one("all the same", [32] * 1000) # identical stats
    one("ascending", range(1, 2000)) # timsort a little better
    one("descending", reversed(range(1, 2000))) # twomerge significantly better
    one("bad timsort case", rtim(100000))
    n = 1000000
    one("bad powersort case", (n-1, 1, 1, n))
    one("another bad timsort", (190000, 180000, 10000, 10000))
    one("bad twomerge case", (190000, 60000, 40000, 10000))
    one("bad shivers case", (30 << 12, 1 << 12, 16 << 12, 1 << 12))

    # "Random" run-length distributions are largely irrelevant, since on
    # randomly ordered input the actual sort is most likely to _force_
    # (via local binary insertion sorts) all runs to length `minrun`.
    # Nevertheless ... there's no clear overall winner in this
    # particular made-up distribution.  On some runs timsort "wins" in
    # the end, on others twomerge, but the total costs are typically
    # within 2% of each other.
    # Later:  powersort almost always dominates both.
    print("randomized trials")
    # totals[0..3] follow the enumerate() order below; totals[4] is
    # powersort.
    totals = [0, 0, 0, 0, 0]
    for trial in range(20):
        runs = []
        for i in range(10000):
            switch = random()
            if switch < 0.80:
                x = randrange(1, 100)
            else:
                x = randrange(1000, 10000)
            runs.append(x)
        for i, which in enumerate([timsort, twomerge, shivers, shivers2]):
            cost, depth = sort(runs, which)
            print(f"{which.__name__:8s}", cost, depth)
            totals[i] += cost
        #cost = ideal(runs)
        #print(" ideal", cost)
        #totals[2] += cost
        #cost = greedy(runs)
        #print(" greedy", cost)
        #totals[3] += cost
        cost, depth = powersort(runs)
        print("power ", cost, depth)
        totals[4] += cost
        ptotal = totals[4]
        print("sofar totals", totals)
        print("% wrt powersort",
              " ".join(f"{(v - ptotal)/ptotal:.2%}" for v in totals))

# Disabled: in-place permutation generator.
if 0:
    def p(xs):
        """Generate every permutation of list `xs` by in-place swaps.

        Each yielded value is `xs` itself, so copy it if it must be kept.
        """
        def inner(i):
            if i == n:
                yield xs
                return
            orig = xs[i]
            for j in range(i, n):
                xs[i] = xs[j]
                xs[j] = orig
                yield from inner(i+1)
                xs[j] = xs[i]
                xs[i] = orig
        n = len(xs)
        return inner(0)

# Enabled: over all run-length triples in 1..50, report each new worst
# ratio of the chosen policy's cost to the ideal cost.
if 1:
    from itertools import product
    worst = 0.0
    first = 1
    for t in product(list(range(1, 51)), repeat=3):
        if t[0] > first:
            print(t)
            first = t[0]
        g = ideal(t)
        #b, ignore = powersort(t)
        #b, ignore = new_powersort(t)
        #b, ignore = sort(t, timsort)
        #b, ignore = sort(t, twomerge)
        b, ignore = sort(t, shivers)
        #b = greedy(t)
        assert g <= b
        if g != b:
            ratio = b / g
            if ratio > worst:
                print(t, b, g, sort(t, timsort), sort(t, twomerge),
                      sort(t, shivers), powersort(t), ratio)
                worst = ratio
    print("done")

# n-1, 1, 1, n
# sum = 2n + 1
# best = 2 + n+1 + 2n+1 = 3n + 4
# powersort = 4n + 2
# a = (n-1)/2 / (2*n + 1) = (n-1)/(4n+2)
# b = (n - 1 + 1/2) / (2*n + 1) = (2n-1)/(4n+2)
# so (0, n-1, 1) has power 2
# then (n-1, 1, 1)
# a = n-1 + 1/2 = n - 1/2; (2n-1)/(4n+2)
# b = n + 1/2; (2n+1)/(4n+2)
# so has power 1, and n-1 is merged with 1 first

# Disabled: head-to-head tally of two policies ("old" = c1/s1,
# "new" = c2/s2) over all run-length triples in 1..40.
if 0:
    def pch(old, new):
        """Format the relative change from `old` to `new` as a percent."""
        return f"{(new - old)/old:.4%}"

    from itertools import product
    cost1 = cost2 = 0
    oldcostwon = newcostwon = costtied = 0
    stack1 = stack2 = 0
    oldstackwon = newstackwon = stacktied = 0

    def disp():
        """Print the running cost and stack-depth tallies."""
        print(cost1, cost2, pch(cost1, cost2))
        print(oldcostwon, newcostwon, costtied)
        print(stack1, stack2, pch(stack1, stack2))
        print(oldstackwon, newstackwon, stacktied)

    first = 1
    for t in product(range(1, 41), repeat=3):
        if t[0] > first:
            print(t)
            first = t[0]
            disp()
        c1, s1 = sort(t, timsort)
        #c2, s2 = sort(t, twomerge)
        #c2, s2 = powersort(t)
        #c1, s1 = powersort(t)
        c2, s2 = sort(t, shivers)
        cost1 += c1
        cost2 += c2
        stack1 += s1
        stack2 += s2
        if c1 < c2:
            oldcostwon += 1
            #print(t, c1, s2, c2, s2)
            #assert False
        elif c1 > c2:
            newcostwon += 1
            #print(t, c1, s2, c2, s2)
            #assert False
        else:
            costtied += 1
        if s1 < s2:
            oldstackwon += 1
        elif s1 > s2:
            newstackwon += 1
        else:
            stacktied += 1
    print("done")
    disp()