diff -r 14c52bb996be Lib/statistics.py
--- a/Lib/statistics.py	Mon Oct 03 09:15:39 2016 -0700
+++ b/Lib/statistics.py	Mon Oct 03 20:18:34 2016 +0100
@@ -89,7 +89,7 @@
 import numbers
 
 from fractions import Fraction
-from decimal import Decimal
+from decimal import Decimal, DefaultContext, MAX_EMAX, MIN_EMIN
 from itertools import groupby, chain
 from bisect import bisect_left, bisect_right
 
@@ -103,6 +103,13 @@
 
 # === Private utilities ===
 
+# Decimal module context used during polishing step of nth root computation.
+_NTH_ROOT_CONTEXT = DefaultContext.copy()
+_NTH_ROOT_CONTEXT.prec = 25
+_NTH_ROOT_CONTEXT.Emax = MAX_EMAX
+_NTH_ROOT_CONTEXT.Emin = MIN_EMIN
+
+
 def _sum(data, start=0):
     """_sum(data [, start]) -> (type, sum, count)
 
@@ -304,8 +311,11 @@
     def __init__(self):
         raise TypeError('namespace only, do not instantiate')
 
-    def nth_root(x, n):
-        """Return the positive nth root of numeric x.
+    def nth_root(x, n, scale=0):
+        """Return the positive nth root of numeric x. If scale is given,
+        return the positive nth root of x * 2**scale.
+
+        The scale parameter is not supported for Decimal inputs.
 
         This may be more accurate than ** or pow():
 
@@ -322,97 +332,52 @@
         """
         if not isinstance(n, int):
             raise TypeError('degree n must be an int')
+        if not isinstance(scale, int):
+            raise TypeError('scale factor must be an int')
         if n < 2:
             raise ValueError('degree n must be 2 or more')
         if isinstance(x, decimal.Decimal):
+            if scale:
+                raise NotImplementedError(
+                    "Decimal nth root with a binary scale not implemented.")
             return _nroot_NS.decimal_nroot(x, n)
         elif isinstance(x, numbers.Real):
-            return _nroot_NS.float_nroot(x, n)
+            return _nroot_NS.float_nroot(x, n, scale)
         else:
-            raise TypeError('expected a number, got %s') % type(x).__name__
+            raise TypeError('expected a number, got %s' % type(x).__name__)
 
-    def float_nroot(x, n):
+    def float_nroot(x, n, scale,
+                    D2=Decimal(2),
+                    pow=_NTH_ROOT_CONTEXT.power,
+                    mul=_NTH_ROOT_CONTEXT.multiply,
+                    div=_NTH_ROOT_CONTEXT.divide,
+                    sub=_NTH_ROOT_CONTEXT.subtract):
         """Handle nth root of Reals, treated as a float."""
         assert isinstance(n, int) and n > 1
         if x < 0:
             raise ValueError('domain error: root of negative number')
-        elif x == 0:
-            return math.copysign(0.0, x)
-        elif x > 0:
-            try:
-                isinfinity = math.isinf(x)
-            except OverflowError:
-                return _nroot_NS.bignum_nroot(x, n)
-            else:
-                if isinfinity:
-                    return float('inf')
-                else:
-                    return _nroot_NS.nroot(x, n)
-        else:
-            assert math.isnan(x)
-            return float('nan')
 
-    def nroot(x, n):
-        """Calculate x**(1/n), then improve the answer."""
-        # This uses math.pow() to calculate an initial guess for the root,
-        # then uses the iterated nroot algorithm to improve it.
-        #
-        # By my testing, about 8% of the time the iterated algorithm ends
-        # up converging to a result which is less accurate than the initial
-        # guess. [FIXME: is this still true?] In that case, we use the
-        # guess instead of the "improved" value. This way, we're never
-        # less accurate than math.pow().
-        r1 = math.pow(x, 1.0/n)
-        eps1 = abs(r1**n - x)
-        if eps1 == 0.0:
-            # r1 is the exact root, so we're done. By my testing, this
-            # occurs about 80% of the time for x < 1 and 30% of the
-            # time for x > 1.
-            return r1
-        else:
-            try:
-                r2 = _nroot_NS.iterated_nroot(x, n, r1)
-            except RuntimeError:
-                return r1
-            else:
-                eps2 = abs(r2**n - x)
-                if eps1 < eps2:
-                    return r1
-                return r2
+        # Convert ints to floats, and pull apart into fraction and exponent.
+        m, e = _frexp_gen(x)
 
-    def iterated_nroot(a, n, g):
-        """Return the nth root of a, starting with guess g.
+        # Handle special cases.
+        if m == 0.0 or math.isinf(m) or math.isnan(m):
+            return m
 
-        This is a special case of Newton's Method.
-        https://en.wikipedia.org/wiki/Nth_root_algorithm
-        """
-        np = n - 1
-        def iterate(r):
-            try:
-                return (np*r + a/math.pow(r, np))/n
-            except OverflowError:
-                # If r is large enough, r**np may overflow. If that
-                # happens, r**-np will be small, but not necessarily zero.
-                return (np*r + a*math.pow(r, -np))/n
-        # With a good guess, such as g = a**(1/n), this will converge in
-        # only a few iterations. However a poor guess can take thousands
-        # of iterations to converge, if at all. We guard against poor
-        # guesses by setting an upper limit to the number of iterations.
-        r1 = g
-        r2 = iterate(g)
-        for i in range(1000):
-            if r1 == r2:
-                break
-            # Use Floyd's cycle-finding algorithm to avoid being trapped
-            # in a cycle.
-            # https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare
-            r1 = iterate(r1)
-            r2 = iterate(iterate(r2))
-        else:
-            # If the guess is particularly bad, the above may fail to
-            # converge in any reasonable time.
-            raise RuntimeError('nth-root failed to converge')
-        return r2
+        # Normalise. We want nth_root(m * 2**(e + scale)), which
+        # is 2**g_scale * nth_root(frac * 2**r).
+        frac = 2.0 * m
+        g_scale, r = divmod(e + scale - 1, n)
+
+        # Approximation to nth_root(frac * 2**r), using the libm pow.
+        g = frac**(1/n) * 2.0**(r/n)
+
+        # Polish the result with a single iteration of Newton's method in
+        # extended precision.
+        Dg = Decimal(g)
+        g_polished = g - float(sub(Dg, div(mul(Decimal(frac), pow(D2, r)),
+                                           pow(Dg, n-1)))) / n
+        return math.ldexp(g_polished, g_scale)
 
     def decimal_nroot(x, n):
         """Handle nth root of Decimals."""
@@ -450,34 +415,61 @@
                 return r1
             r0, r1 = r1, iterate(r1)
 
-    def bignum_nroot(x, n):
-        """Return the nth root of a positive huge number."""
-        assert x > 0
-        # I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2)
-        # and that for sufficiently big x the error is acceptable.
-        # We now halve x until it is small enough to get the root.
-        m = 0
-        while True:
-            x //= 2
-            m += 1
-            try:
-                y = float(x)
-            except OverflowError:
-                continue
-            break
-        a = _nroot_NS.nroot(y, n)
-        # At this point, we want the nth-root of 2**m, or 2**(m/n).
-        # We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n.
-        q, r = divmod(m, n)
-        b = 2**q * _nroot_NS.nroot(2**r, n)
-        return a * b
-
 
 # This is the (private) function for calculating nth roots:
 _nth_root = _nroot_NS.nth_root
 assert type(_nth_root) is type(lambda: None)
 
 
+def _frexp_gen(n):
+    """Version of math.frexp that works for floats and arbitrary integers.
+
+    This version of frexp avoids the OverflowError that occurs for math.frexp
+    applied to integers outside the range of a float.
+
+    Examples
+    --------
+    >>> _frexp_gen(10)
+    (0.625, 4)
+    >>> _frexp_gen(16)
+    (0.5, 5)
+    >>> _frexp_gen(-10)
+    (-0.625, 4)
+    >>> _frexp_gen(2**1023)
+    (0.5, 1024)
+    >>> _frexp_gen(2**1024)
+    (0.5, 1025)
+    >>> _frexp_gen(2**1024 - 1)
+    (0.5, 1025)
+    >>> _frexp_gen(2**1024 - 2**970)  # round-ties-to-even check
+    (0.5, 1025)
+    >>> _frexp_gen(2**1024 + 2**971)  # round-ties-to-even check
+    (0.5, 1025)
+    >>> _frexp_gen(2**1024 + 3 * 2**971)  # round-ties-to-even check
+    (0.5000000000000002, 1025)
+    >>> _frexp_gen(2**1024 + 5 * 2**971)  # round-ties-to-even check
+    (0.5000000000000002, 1025)
+    >>> _frexp_gen(2**1024 + 7 * 2**971)  # round-ties-to-even check
+    (0.5000000000000004, 1025)
+    >>> _frexp_gen(0)
+    (0.0, 0)
+    >>> _frexp_gen(-2**1024)
+    (-0.5, 1025)
+    """
+    try:
+        return math.frexp(n)
+    except OverflowError:
+        # Does correct rounding assuming IEEE 754 binary64 format, but
+        # shouldn't give badly wrong results for other formats. For abs(n) in
+        # the range [2**53, 2**1024 - 2**970), the code below should give
+        # identical results to math.frexp(n). For the purposes of _product,
+        # the abs(m) == 1.0 check can be omitted - it's safe to simply return
+        # (m, e).
+        e = n.bit_length()
+        m = ((n >> e - 54) - (-n >> e - 54)) / (1 << 55)
+        return (m / 2.0, e + 1) if abs(m) == 1.0 else (m, e)
+
+
 def _product(values):
     """Return product of values as (exponent, mantissa)."""
     errmsg = 'mixed Decimal and float is not supported'
@@ -499,11 +491,11 @@
     #
     #   x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
     #
-    mant, scale = 1, 0  #math.frexp(prod)  # FIXME
+    mant, scale = _frexp_gen(prod)
     for y in chain([x], values):
         if isinstance(y, Decimal):
             raise TypeError(errmsg)
-        m1, e1 = math.frexp(y)
+        m1, e1 = _frexp_gen(y)
         m2, e2 = math.frexp(mant)
         scale += (e1 + e2)
         mant = m1*m2