"""Subquadratic algorithms for

 - division
 - integer to string conversion
 - integer square root (with remainder)

The subquadratic division algorithm is based on the paper `Fast Recursive
Division' by Burnikel and Ziegler.  The fast str algorithm is the usual
divide-and-conquer algorithm.  The square root algorithm is based on
'Karatsuba Square Root' by Paul Zimmermann.
"""

# NOTE(review): this copy of the module is damaged.  Several statements
# appear to have lost the text between a '<' and a later '>' (each one is
# marked with a NOTE(review) comment below and kept verbatim), and
# sqrtrem_base -- which sqrtrem calls -- is not defined anywhere in this
# copy.  The module will not parse until the original text is restored.
# The code uses xrange, i.e. it is written for Python 2.

# Under DIV_LIMIT bits we use the normal non-recursive division.
DIV_LIMIT = 8400
DIV_LIMIT_POWER = 1 << DIV_LIMIT

# Use builtin str under STR_LIMIT decimal digits.
STR_LIMIT = 1000
STR_LIMIT_POWER = 10**STR_LIMIT

# Number of bits at which to use the non-recursive square root.
SQRT_LIMIT = 500


def split(n, k):
    """Return divmod(n, 2**k), computed with a shift and a mask."""
    return n >> k, n & ((1 << k) - 1)


def nbits(n, len_ = len, correction = {
        '0': 4, '1': 3, '2': 2, '3': 2, '4': 1, '5': 1, '6': 1, '7': 1,
        '8': 0, '9': 0, 'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0}):
    """Number of bits in the nonnegative integer n, or 0 if n == 0.

    The bit length is derived from the hexadecimal representation of n:
    four bits per hex digit, minus the number of leading zero bits of the
    first digit (looked up in `correction`).  `len_` and `correction` are
    default arguments, evaluated once when the function is defined
    (presumably so they are fast local lookups); callers are not expected
    to pass them.

    Raises ValueError if n is negative.
    """
    if n < 0:
        raise ValueError("The argument to nbits should be nonnegative.")
    hex_n = "%x" % n
    return (len_(hex_n)<<2) - correction[hex_n[0]]


def divmod_fast(a, b):
    """Compute divmod(a, b) for arbitrary integers a and b, using a
    recursive divide-and-conquer algorithm when a and b are large.

    If |b| < 2**DIV_LIMIT the builtin divmod is used directly.  Otherwise
    a is processed as a long division in base 2**nbits(b), each digit step
    being delegated to div2n1n.  The sign handling at the end reproduces
    the builtin's conventions (floor quotient, remainder with the sign of b).
    """
    # if b is small (including the case b == 0), use divmod directly
    if -DIV_LIMIT_POWER < b < DIV_LIMIT_POWER:
        return divmod(a, b)
    # reduce to case b > 0, a >= 0; the two flags record how to undo this
    negate = b < 0
    if negate:
        a, b = -a, -b
    invert = a < 0
    if invert:
        # ~a == -a - 1, which is >= 0 here
        a = ~a
    # do a long division, nbits(b) bits at a time
    n = nbits(b)
    # number of blocks in the quotient (possibly one too large)
    q_blocks = nbits(a) // n
    # r starts as the leading nbits(a) - q_blocks*n (< n) bits of a, so
    # r < b, which keeps every div2n1n call below within its precondition
    # a < 2**n * b
    r, tail = split(a, q_blocks*n)
    q = 0
    for i in xrange(q_blocks-1, -1, -1):
        # bring down the next n-bit block of a, most significant first
        # NOTE(review): `next` shadows the builtin of the same name here.
        next, tail = split(tail, i*n)
        qd, r = div2n1n(r << n | next, b, n)
        q = q << n | qd
    # undo reductions
    if invert:
        # -a-1 == q*b + r  implies  a == (~q)*b + (b - 1 - r), and
        # ~r + b == b - 1 - r
        q, r = ~q, ~r+b
    if negate:
        # -a == q*(-b) + r  implies  a == q*b - r
        r = -r
    return q, r


def str_fast(n):
    """Compute str(n) for the integer n, using a subquadratic
    divide-and-conquer algorithm.

    n is split repeatedly with divmod_fast by the powers of ten
    10**(blocksize * 2**j) until blocks of blocksize decimal digits remain;
    the blocks are formatted with %-formatting and joined.
    """
    # NOTE(review): nbits raises ValueError for a negative argument, so this
    # first statement makes str_fast raise for every n < 0, and the later
    # `acc = ['-']` branch never negates n before using it either.
    # Presumably the sign should be handled first (with n = -n) -- confirm
    # before passing negative values.
    # Find a tight upper bound on the number of digits in n
    digits = -(-nbits(n)*146//485)
    if digits <= 1024:
        return str(n)

    # printn appends the decimal digits of n to acc, splitting n with
    # divmod_fast by powers[j], then powers[j-1], ..., down to single blocks
    # (j < 0).  head is true while no digit of the number has been emitted
    # yet: leading zero quotients are skipped, the first block is printed
    # without padding, and every later block is zero-padded via formatstr.
    # formatstr is a closure variable assigned further down, before the
    # first call to printn.
    def printn(n, j, powers, head, acc):
        if j >= 0:
            q, r = divmod_fast(n, powers[j])
            printq = not head or q
            if printq:
                printn(q, j-1, powers, head, acc)
            printn(r, j-1, powers, not printq, acc)
        else:
            if head:
                acc.append("%d" % n)
            else:
                acc.append(formatstr % n)

    # for small n use the built-in str
    if -STR_LIMIT_POWER < n < STR_LIMIT_POWER:
        return str(n)
    if n < 0:
        acc = ['-']
    else:
        acc = []
    # find an upper bound on the number of decimal digits of n:
    # ceiling(nbits(n)*log(2)/log(10)).  146/485 is an approximation to (and
    # upper bound for) log(2)/log(10).
    digits = -(-nbits(n)*146//485)
    # number of blocks = ceiling(digits / STR_LIMIT), rounded up to a power of 2
    depth = nbits((digits-1) // STR_LIMIT)
    assert depth > 0
    blocks = 1 << depth
    # round digits up to a multiple of blocks, so that every block has
    # exactly blocksize digits
    digits += -digits % blocks
    blocksize = digits // blocks
    # precompute the divisors powers[i] == 10**(blocksize * 2**i), i.e.
    # 10**(digits/blocks), 10**(2*digits/blocks), 10**(4*digits/blocks),
    # ..., 10**(digits/2)
    M = 10**blocksize
    powers = [M]
    for i in xrange(depth-1):
        M *= M
        powers.append(M)
    assert len(powers) == depth
    # zero-padded format used for every block except the leading one
    formatstr = "%%0%dd" % blocksize
    printn(n, len(powers)-1, powers, True, acc)
    return "".join(acc)


def isqrt(n):
    """Return floor(sqrt(n)) for the nonnegative integer n, via sqrtrem.

    nbits raises ValueError if n is negative.
    """
    s, r = sqrtrem(n, nbits(n))
    #assert n == s*s+r and 0 <= r <= 2*s
    return s


# powers_of_ten[k] stores 10**2**k, i.e. 10**(2**k), for k = 0..4
powers_of_ten = [10, 100, 10000, 10**8, 10**16]


def fast_int(s, k):
    """Given a string representing a positive integer, with
    2**k < len(s) <= 2**(k+1), compute int(s)."""
    assert 2**k < len(s) <= 2**(k+1)
    # Base case
    if k < 0:
        return int(s)
    # Otherwise, recurse
    # NOTE(review): text is missing from the next line.  It runs from the
    # middle of fast_int's recursive step into the tail of a different
    # function: it calls _str_to_int, whose def line (and the `if` matching
    # this `else:`) do not appear in this copy.  Kept verbatim; restore it
    # from the original source.
    head, tail = s[1<>1 else: hl = l>>1 high = s[:-hl] low = s[-hl:] return _str_to_int(high) * 10**hl + _str_to_int(low)


def div2n1n(a, b, n):
    """Divide a 2n-bit nonnegative integer a by an n-bit positive integer b,
    using a recursive divide-and-conquer algorithm.

    Inputs:
      n is a positive integer
      b is a positive integer with exactly n bits
      a is a nonnegative integer such that a < 2**n * b

    Output:
      (q, r) such that a = b*q+r and 0 <= r < b.
    """
    if n <= DIV_LIMIT:
        return divmod(a, b)
    # if n is odd then double a and b; halve remainder later
    pad = n&1
    if pad:
        a <<= 1
        b <<= 1
        n += 1
    half_n = n >> 1
    # NOTE(review): text is missing from the next line.  Presumably it read
    #     mask = (1<<half_n) - 1
    #     b1, b2 = b>>half_n, b&mask
    # (mask, b1 and b2 are all used below) -- confirm against the original.
    mask = (1<>half_n, b&mask
    # divide in two 3-by-2 steps over half_n-bit pieces of a: first the top
    # three pieces, then the remainder followed by the lowest piece
    q1, r = div3n2n(a>>n, (a>>half_n)&mask, b, b1, b2, half_n)
    q2, r = div3n2n(r, a&mask, b, b1, b2, half_n)
    if pad:
        r >>= 1
    return q1 << half_n | q2, r


def div3n2n(a12, a3, b, b1, b2, n):
    """Helper function for div2n1n; not intended to be called directly.

    As called from div2n1n, a12 holds the top two n-bit pieces of the
    dividend and a3 the next piece; b1 and b2 are presumably the high and
    low n-bit halves of b (their definition in div2n1n is damaged in this
    copy -- confirm).
    """
    if a12 >> n == b1:
        # NOTE(review): text is missing from the next line.  The lost span
        # presumably contained the rest of div3n2n and the definition of
        # sqrtrem_base (called by sqrtrem, not defined elsewhere in this
        # copy).  Kept verbatim; restore it from the original source.
        q, r = (1 << n) - 1, a12 - (b1<> 1


# sqrtrem_list[a] == (s, r) with a == s*s + r and 0 <= r <= 2*s, for
# 0 <= a < 256 (the entries for a given s cover a = s*s .. s*s + 2*s).
# sqrtrem_base2 uses it for inputs of at most 8 bits.
sqrtrem_list = [
    (0, 0),
    (1, 0), (1, 1), (1, 2),
    (2, 0), (2, 1), (2, 2), (2, 3), (2, 4),
    (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6),
    (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8),
    (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8),
    (5, 9), (5, 10),
    (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7),
    (6, 8), (6, 9), (6, 10), (6, 11), (6, 12),
    (7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7),
    (7, 8), (7, 9), (7, 10), (7, 11), (7, 12), (7, 13), (7, 14),
    (8, 0), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7),
    (8, 8), (8, 9), (8, 10), (8, 11), (8, 12), (8, 13), (8, 14), (8, 15),
    (8, 16),
    (9, 0), (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 7),
    (9, 8), (9, 9), (9, 10), (9, 11), (9, 12), (9, 13), (9, 14), (9, 15),
    (9, 16), (9, 17), (9, 18),
    (10, 0), (10, 1), (10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7),
    (10, 8), (10, 9), (10, 10), (10, 11), (10, 12), (10, 13), (10, 14), (10, 15),
    (10, 16), (10, 17), (10, 18), (10, 19), (10, 20),
    (11, 0), (11, 1), (11, 2), (11, 3), (11, 4), (11, 5), (11, 6), (11, 7),
    (11, 8), (11, 9), (11, 10), (11, 11), (11, 12), (11, 13), (11, 14), (11, 15),
    (11, 16), (11, 17), (11, 18), (11, 19), (11, 20), (11, 21), (11, 22),
    (12, 0), (12, 1), (12, 2), (12, 3), (12, 4), (12, 5), (12, 6), (12, 7),
    (12, 8), (12, 9), (12, 10), (12, 11), (12, 12), (12, 13), (12, 14), (12, 15),
    (12, 16), (12, 17), (12, 18), (12, 19), (12, 20), (12, 21), (12, 22), (12, 23),
    (12, 24),
    (13, 0), (13, 1), (13, 2), (13, 3), (13, 4), (13, 5), (13, 6), (13, 7),
    (13, 8), (13, 9), (13, 10), (13, 11), (13, 12), (13, 13), (13, 14), (13, 15),
    (13, 16), (13, 17), (13, 18), (13, 19), (13, 20), (13, 21), (13, 22), (13, 23),
    (13, 24), (13, 25), (13, 26),
    (14, 0), (14, 1), (14, 2), (14, 3), (14, 4), (14, 5), (14, 6), (14, 7),
    (14, 8), (14, 9), (14, 10), (14, 11), (14, 12), (14, 13), (14, 14), (14, 15),
    (14, 16), (14, 17), (14, 18), (14, 19), (14, 20), (14, 21), (14, 22), (14, 23),
    (14, 24), (14, 25), (14, 26), (14, 27), (14, 28),
    (15, 0), (15, 1), (15, 2), (15, 3), (15, 4), (15, 5), (15, 6), (15, 7),
    (15, 8), (15, 9), (15, 10), (15, 11), (15, 12), (15, 13), (15, 14), (15, 15),
    (15, 16), (15, 17), (15, 18), (15, 19), (15, 20), (15, 21), (15, 22), (15, 23),
    (15, 24), (15, 25), (15, 26), (15, 27), (15, 28), (15, 29), (15, 30)]


def sqrtrem_base2(a, bits):
    """An attempt to write a faster sqrtrem_base.

    Same contract as sqrtrem: bits == nbits(a), and the result (s, r)
    satisfies a == s*s + r, 0 <= r <= 2*s.  Inputs of at most 8 bits are
    looked up in sqrtrem_list; larger inputs take one step of the same
    recursion as sqrtrem, but with the builtin divmod in place of div2n1n
    (the recursive call still goes through sqrtrem).
    """
    if bits <= 8:
        return sqrtrem_list[a]
    # pad a with two zero bits when bits % 4 is 1 or 2 (undone at the end)
    pad = (bits+1)&2
    if pad:
        a <<= 2
        bits += 2
    qbits = (bits+1) >> 2
    hbits = qbits << 1
    a32, a10 = split(a, hbits)
    sp, rp = sqrtrem(a32, bits-hbits)
    a1, a0 = split(a10, qbits)
    # q, u = divmod(rp*2**qbits + a1, 2*sp)
    if rp >> 1 == sp:
        q = 1 << qbits
        u = a1
    else:
        q, u = divmod((rp << qbits-1) + (a1 >> 1), sp)
        u = (u<<1) + (a1&1)
    s = (sp << qbits) + q
    r = (u << qbits) + a0 - q*q
    if r < 0:
        # s is one too large: a - (s-1)**2 == r + 2*s - 1
        r = r + (s << 1) - 1
        s -= 1
    if pad:
        # undo the padding of a by a factor of 4
        r >>= 2
        if s & 1:
            s >>= 1
            r = r + s + 1
        else:
            s >>= 1
    return s, r


def sqrtrem(a, bits):
    """Given a nonnegative integer a, together with a nonnegative integer
    bits such that bits == nbits(a), find s and r such that a = s*s+r,
    0 <= r <= 2*s.

    Inputs of at most SQRT_LIMIT bits go to sqrtrem_base.  Larger inputs
    use the recursive Karatsuba square root: a is viewed as four qbits-bit
    pieces a3:a2:a1:a0, the square root (sp, rp) of the top half a3:a2 is
    computed recursively, and the low half is folded in with one division
    by 2*sp (done with div2n1n) followed by a correction step.
    """
    if bits <= SQRT_LIMIT:
        return sqrtrem_base(a, bits)
    # number of bits should have the form 4n or 4n-1 (so the square
    # root has 2n bits); when bits % 4 is 1 or 2, a is multiplied by 4
    # here and the result is corrected at the end
    pad = (bits+1)&2
    if pad:
        a <<= 2
        bits += 2
    qbits = (bits+1) >> 2
    hbits = qbits << 1
    # a32 is the top half of a (pieces a3:a2), a10 the low hbits bits (a1:a0)
    a32, a10 = split(a, hbits)
    sp, rp = sqrtrem(a32, bits-hbits)
    a1, a0 = split(a10, qbits)
    # q, u = divmod(rp*2**qbits + a1, 2*sp)
    if rp >> 1 == sp:
        # rp == 2*sp (given 0 <= rp <= 2*sp): the quotient would need
        # qbits+1 bits, which div2n1n's precondition a < 2**n * b does not
        # allow, so q and u are set directly
        q = 1 << qbits
        u = a1
    else:
        # divide the halved dividend by sp, then put the dropped low bit of
        # a1 back into the (doubled) remainder
        q, u = div2n1n((rp << qbits-1) + (a1 >> 1), sp, qbits)
        u = (u<<1) + (a1&1)
    s = (sp << qbits) + q
    r = (u << qbits) + a0 - q*q
    if r < 0:
        # s is one too large: a - (s-1)**2 == r + 2*s - 1
        r = r + (s << 1) - 1
        s -= 1
    if pad:
        # undo the padding: the root of 4*a is halved; when it was odd the
        # remainder gains s + 1 (with s already halved)
        r >>= 2
        if s & 1:
            s >>= 1
            r = r + s + 1
        else:
            s >>= 1
    return s, r

# NOTE(review): commented-out code (apparently a random-input test setup);
# its last line is truncated in this copy.
#from random import randrange, seed
#seed("Little Tommy Tucker")
#m = 4000000
#n = randrange(1<