"""Speed up comb() with a fixed column of precomputed values.

Binomial coefficients are computed by C(n, k) in three tiers:

* results that fit in 64 bits come from M(), which combines precomputed odd
  parts of factorials (and their inverses mod 2**64) with a power-of-two
  shift;
* comb(n, FixedJ) for larger n is looked up in the KnownComb table;
* everything else is split recursively using the identity
  C(n, k) * C(k, j) == C(n, j) * C(n - j, k - j).

The second half of the file packs KnownComb into three compact byte arrays
and provides lookup_comb() to read values back.  Running the file as a
script cross-checks everything against math.comb and prints the packed
table size.
"""

from math import comb, factorial

## Precomputations ###################################################

FixedJ = 20   # Which diagonal (column k) of Pascal's triangle to precompute
Jlim = 225    # Largest n in the table of precomputed diagonal values
Mlim = 128    # Largest n for the 64-bit modular arithmetic tables

# Split every n! (n <= Mlim) as n! == F[n] << S[n] with F[n] odd, and keep
# Finv[n], the inverse of F[n] modulo 2**64 (odd numbers are invertible
# modulo a power of two).
S, F, Finv = [], [], []
for n in range(Mlim + 1):
    f = factorial(n)
    s = (f & -f).bit_length() - 1     # number of trailing zero bits of n!
    odd = f >> s
    S.append(s)
    F.append(odd)
    Finv.append(pow(odd, -1, 2**64))

# k2n[k] is the largest n <= Mlim for which comb(n, k) fits in 64 bits.
# Only k < 35 is tabulated: C() reduces k to at most n // 2, and for k >= 35
# that forces n >= 70, where comb(70, 35) already exceeds 2**64.
k2n = [max(n for n in range(Mlim + 1) if comb(n, k).bit_length() <= 64)
       for k in range(35)]

# comb(n, FixedJ) for every n in [min(k2n) + 1, Jlim].  C() only reads the
# entries with n > k2n[FixedJ]; smaller n are served by M() first.
KnownComb = {n: comb(n, FixedJ) for n in range(min(k2n) + 1, Jlim + 1)}

## Runtime computations ##############################################


def M(n, k):
    """Return comb(n, k) using 64-bit modular arithmetic.

    Valid only when 0 <= k <= n and comb(n, k) fits in 64 bits, i.e.
    n <= k2n[min(k, n - k)].  The odd part of the result is computed as
    F[n] * Finv[k] * Finv[n - k] mod 2**64, which is exact because that odd
    part is below 2**64; the power of two is then restored with a shift.

    Raises AssertionError (when assertions are enabled) for arguments
    outside that range.
    """
    assert 0 <= k <= n, f'comb({n}, {k}) requires 0 <= k <= n'
    r = min(k, n - k)   # comb is symmetric, and k2n is indexed by the smaller k
    assert r < len(k2n) and n <= k2n[r], f'comb({n}, {k}) exceeds 64-bits'
    return (F[n] * Finv[k] * Finv[n - k] % 2**64) << (S[n] - S[k] - S[n - k])


def C(n, k):
    """Return the binomial coefficient comb(n, k) exactly.

    Matches math.comb for non-negative integers: returns 0 when k > n and
    raises ValueError when n or k is negative.
    """
    if n < 0 or k < 0:
        raise ValueError('n and k must be non-negative integers')
    if k > n:
        return 0
    k = min(k, n - k)               # symmetry: comb(n, k) == comb(n, n - k)
    if k == 0:
        return 1
    if k == 1:
        return n
    if k < len(k2n) and n <= k2n[k]:
        return M(n, k)              # result fits in 64 bits
    if k == FixedJ and n <= Jlim:
        return KnownComb[n]         # precomputed diagonal
    # Split off j items: C(n, k) * C(k, j) == C(n, j) * C(n - j, k - j), so
    # the division is exact.  When k >= 2 * FixedJ, j == FixedJ, so C(n, j)
    # is served by M() or KnownComb whenever n <= Jlim.
    j = min(k // 2, FixedJ)
    return C(n, j) * C(n - j, k - j) // C(k, j)

## Storing and loading the KnownComb table ############################

# The diagonal is packed into three byte arrays (all big-endian; byteorder
# is spelled out so this also runs on Python < 3.11):
positions = bytearray()    # 2-byte start offset of each value within `values`
sizes = bytearray()        # 1-byte length of each value, in bytes
values = bytearray()       # Variable-length integers

offset = min(k2n) + 1      # First n stored in the table (same start as KnownComb)
for n in range(offset, Jlim + 1):   # include n == Jlim, as KnownComb does
    c = comb(n, FixedJ)
    size = (c.bit_length() + 7) // 8
    positions += len(values).to_bytes(2, 'big')
    sizes += size.to_bytes(1, 'big')
    values += c.to_bytes(size, 'big')

assert len(values) < 2 ** 16  # Positions must fit in 2 bytes
total_table_size = len(positions) + len(sizes) + len(values)


def lookup_comb(n):
    """Return comb(n, FixedJ) decoded from the packed byte tables.

    Valid for offset <= n <= Jlim; raises AssertionError (when assertions
    are enabled) otherwise.
    """
    assert offset <= n <= Jlim, f'n={n} is outside the table range [{offset}, {Jlim}]'
    i = n - offset
    position = int.from_bytes(positions[2*i : 2*i + 2], 'big')
    size = sizes[i]
    return int.from_bytes(values[position : position + size], 'big')

## Self-test ##########################################################


def _self_test():
    """Cross-check C() against math.comb and lookup_comb() against KnownComb."""
    assert all(C(n, k) == comb(n, k) for n in range(Jlim+100) for k in range(n+1))
    print(f'{total_table_size =} bytes')
    assert all(lookup_comb(n) == KnownComb[n] for n in range(offset, Jlim + 1))


if __name__ == '__main__':
    # Run the (relatively expensive) checks only when executed as a script,
    # not on every import.
    _self_test()