from math import frexp, ldexp


def ulp(x):
    """Return the unit in the last place of the positive normal float *x*.

    With x == m * 2**e and 0.5 <= m < 1 (as frexp returns), a 53-bit
    significand gives ulp(x) == 2**(e - 53).  Subnormals and 0.0 are not
    treated specially.
    """
    _, exp = frexp(x)
    return ldexp(0.5, exp - 52)


def floor_nroot(x, n):
    """For positive integers x, n, return the floor of the nth root of x."""
    bl = x.bit_length()
    if bl <= 1000:
        # x still converts to a float without overflow, so use a float guess.
        a = int(x ** (1.0 / n))
    else:
        # x is too large to convert to float; work from its top 53 bits.
        xhi = x >> (bl - 53)
        # x ~= xhi * 2**(bl-53)
        # x**(1/n) ~= xhi**(1/n) * 2**((bl-53)/n)
        a = xhi ** (1.0 / n)
        w, r = divmod(bl - 53, n)
        a *= 2.0 ** (r / n)
        # Now x**(1/n) ~= a * 2**w; convert that to an integer exactly.
        m, e = frexp(a)
        a = int(m * 2**53)
        e += w - 53
        if e >= 0:
            a <<= e
        else:
            a >>= -e
    # A guess of 1 can be horribly slow, since then the next
    # guess is approximately x/n.  So force the first guess to
    # be at least 2.  If that's too large, fine, it will be
    # cut down to 1 right away.
    a = max(a, 2)
    # After one integer Newton step, a >= floor(x**(1/n)) (AM-GM), and the
    # following iterates decrease strictly until they reach the floor.
    a = ((n - 1) * a + x // a ** (n - 1)) // n
    while True:
        d = x // a ** (n - 1)
        if a <= d:
            # a**n <= x, and every earlier iterate had a**n > x.
            return a
        a = ((n - 1) * a + d) // n


def nroot(x, n):
    """
    Correctly-rounded nth root (n >= 2) of x, for a finite positive float x.
    """
    if not (x > 0 and n >= 2):
        raise ValueError("x should be positive; n should be at least 2", x, n)
    m, e = frexp(x)
    # int(m * 2**53) is x's significand as an integer.  Write e - 1 as
    # k*n + r.  The shifted integer then equals x * 2**(n*(53 - k)), so its
    # nth root is x**(1/n) * 2**(53 - k) and has exactly 54 bits: 53 result
    # bits plus one rounding bit.
    rootm = floor_nroot(int(m * 2**53) << (53*n + (e-1) % n - 52), n)
    assert rootm.bit_length() == 54, rootm.bit_length()
    # If the rounding bit is set, the true root lies above the halfway point,
    # so round up.  An exact tie would need x to have more than 53
    # significant bits.  Otherwise truncation is the correct rounding.
    if rootm & 1:
        rootm += 1
    return ldexp(rootm, (e-1) // n - 53)


import decimal
c = decimal.DefaultContext.copy()
c.prec = 25


def rootn(x, n, D=decimal.Decimal,
          pow=c.power, sub=c.subtract, div=c.divide):
    """Return x**(1/n) refined by one Newton step in 25-digit decimal.

    The default arguments bind Decimal and the methods of a private
    25-digit context when the function is defined, so they survive the
    ``del decimal, c`` below.  They are presumably not meant to be passed
    by callers.
    """
    g = x**(1.0/n)
    Dg = D(g)
    # Newton correction (g - x/g**(n-1)) / n, computed with extra precision.
    return g - float(sub(Dg, div(D(x), pow(Dg, n-1)))) / n


del decimal, c


def raw(x, n):
    """Return the plain float estimate x**(1/n)."""
    return x**(1.0/n)


def native(x, n):
    """Return x**(1/n), refined by one Newton step in float arithmetic.

    The guess is returned unchanged when g**n already equals x in floats.
    """
    g = x**(1.0/n)
    if g**n == x:
        return g
    return g - (g - x/g**(n-1))/n


def doit(N=1000, nmax=5000, emin=-500, emax=500):
    """Compare cheap nth-root methods against the correctly-rounded nroot().

    For each n in range(2, nmax), draw N random positive floats
    ldexp(base, e), with base in (0.0, 1.0] and emin <= e <= emax.  For each
    method, print a histogram of its error measured in ulps of the correct
    result.  The extended-precision method (rootn) is asserted to be always
    correct.
    """
    from random import random, randrange
    from collections import Counter

    methods = [
        (raw, "x**(1/n)"),
        (native, "with 1 native-precision step"),
        (rootn, "with 1 extended-precision step"),
    ]
    for n in range(2, nmax):
        print("n =", n)
        # random() may return exactly 0.0, which nroot() rejects.
        # 1.0 - random() lies in (0.0, 1.0] instead.
        xs = [ldexp(1.0 - random(), randrange(emin, emax + 1))
              for _ in range(N)]
        perfect = [nroot(x, n) for x in xs]
        for f, tag in methods:
            print(" ", tag)
            c = Counter()
            for x, p in zip(xs, perfect):
                g = f(x, n)
                u = (g - p) / ulp(p)
                assert u.is_integer()
                c[int(u)] += 1
            if len(c) == 1 and 0 in c:
                print(" all correct")
            else:
                assert f is not rootn
                for t in sorted(c.items()):
                    print(" %5d %5d" % t)


if __name__ == "__main__":
    doit()