''' Speed-up comb() with a fixed column of precomputed values

Large binomial coefficients are split with the identity

    C(n, k) == C(n, j) * C(n-j, k-j) // C(k, j)

using j == FixedJ whenever k > FixedJ and n <= Jlim, so that C(n, j)
is answered by a table lookup.  For example:

    C(200, 100) == C(200, 20) * (C(180, 20) * (C(160, 20) * (C(140, 20)
                   * C(120, 20) // C(40, 20)) // C(60, 20)) // C(80, 20))
                   // C(100, 20)

'''

from math import comb, factorial

## Precomputations ###################################################

FixedJ = 20   # Which diagonal of Pascal's triangle to precompute
Jlim = 300    # Largest n in table of precomputed diagonal
Mlim = 128    # Largest n for the 64 bit modular arithmetic tables

# For every n in 0..Mlim write n! == F[n] * 2**S[n] with F[n] odd.
#   S[n]    : exponent of two in n!
#   F[n]    : odd part of n!
#   Finv[n] : multiplicative inverse of F[n] modulo 2**64
S, F, Finv = [], [], []
for n in range(Mlim+1):
    f = factorial(n)
    s = (f & -f).bit_length() - 1      # Count of trailing zero bits in n!
    odd = f >> s
    inv = pow(odd, -1, 2**64)          # Exists because odd is coprime to 2**64
    S.append(s)
    F.append(odd)
    Finv.append(inv)

# k2n[k] is the largest n <= Mlim for which comb(n, k) fits in 64 bits.
k2n = [max(n for n in range(Mlim+1) if comb(n, k).bit_length() <= 64)
       for k in range(35)]

# Diagonal k == FixedJ for every n too large for the 64-bit path, up to Jlim.
KnownComb = {n: comb(n, FixedJ) for n in range(k2n[FixedJ]+1, Jlim+1)}

## Runtime computations ##############################################

def comb64(n: int, k: int) -> int:
    '''comb(n, k) in multiplicative group modulo 64-bits

    Only valid when 0 <= k <= n <= Mlim and the true result fits in
    64 bits (that is, n <= k2n[k]); the caller must check this.

    The odd part of the result is recovered exactly by multiplying the
    odd part of n! by the modular inverses of the odd parts of k! and
    (n-k)! and reducing modulo 2**64.  The power of two is then restored
    with a left shift.
    '''
    return (F[n] * Finv[k] * Finv[n-k] & (2**64-1)) << (S[n] - S[k] - S[n - k])

def comb_iterative(n: int, k: int) -> int:
    '''Straight multiply and divide when k is small.

    After step r the running value equals comb(n, r), so every floor
    division is exact.  Expects 0 <= k <= n.
    '''
    result = 1
    for r in range(1, k+1):
        result *= n - r + 1
        result //= r
    return result

verbose = False   # Set to True to print each recursive decomposition step

def C(n: int, k: int) -> int:
    '''Return the binomial coefficient "n choose k".

    Follows the conventions of math.comb(): the result is 0 when k > n,
    and ValueError is raised when either argument is negative.

    Strategy, tried in order after reducing k to min(k, n - k):
      * k of 0 or 1 is answered directly,
      * results that fit in 64 bits use the comb64() tables,
      * k == FixedJ with n <= Jlim is a KnownComb lookup,
      * k < 10 uses comb_iterative(),
      * anything else recurses on C(n, j) * C(n-j, k-j) // C(k, j).
    '''
    # Without these guards min(k, n - k) below can go negative and then be
    # used as a list index or loop bound, silently giving wrong answers.
    if n < 0 or k < 0:
        raise ValueError('n and k must be non-negative integers')
    if k > n:
        return 0
    k = min(k, n - k)                   # Symmetry: C(n, k) == C(n, n-k)
    if k == 0:
        return 1
    if k == 1:
        return n
    if k < len(k2n) and n <= k2n[k]:
        return comb64(n, k)             # 64-bit fast case
    if k == FixedJ and n <= Jlim:
        return KnownComb[n]             # Precomputed diagonal
    if k < 10:
        return comb_iterative(n, k)     # Non-recursive for small k
    # Prefer peeling off the precomputed diagonal; otherwise halve k.
    j = FixedJ if k > FixedJ and n <= Jlim else k // 2
    if verbose:
        print(f'C({n}, {k}) = C({n}, {j}) * C({n-j}, {k-j}) // C({k}, {j})')
    return C(n, j) * C(n-j, k-j) // C(k, j)   # Recursive case

## Test ###############################################################

# range(n+2) also checks the k == n + 1 case, where the answer is zero.
assert all(C(n, k) == comb(n, k) for n in range(Jlim+100) for k in range(n+2))

## Storing and loading the KnownComb table ############################

# Compact byte encoding of KnownComb.  Entry i describes n == offset + i.
positions = bytearray()   # 2-byte chunks
sizes = bytearray()       # 1-byte chunks
values = bytearray()      # Variable length chunks
offset = k2n[FixedJ]+1    # Starting point for table
for n in range(offset, Jlim + 1):
    c = comb(n, FixedJ)
    size = (c.bit_length() + 7) // 8                # Bytes needed to hold c
    positions += len(values).to_bytes(2, 'big')     # Where this value starts
    sizes += size.to_bytes(1, 'big')
    values += c.to_bytes(size, 'big')
assert len(values) < 2 ** 16   # Positions must fit in 2 bytes
total_table_size = len(positions) + len(sizes) + len(values)
print(f'{total_table_size =} bytes')

def lookup_comb(n: int) -> int:
    '''Return comb(n, FixedJ) decoded from the packed byte tables.

    Only defined for offset <= n <= Jlim.
    '''
    assert offset <= n <= Jlim, 'n is outside the range of the packed table'
    i = n - offset
    position = int.from_bytes(positions[2*i : 2*i + 2], 'big')
    size = sizes[i]
    return int.from_bytes(values[position : position + size], 'big')

assert all(lookup_comb(n) == KnownComb[n] for n in range(offset, Jlim+1))