""" Proof of correctness for crsum ============================== Terminology ----------- Call a list p of floats *nonoverlapping* if all its elements are nonzero and nonspecial, if the elements are increasing in magnitude, and if p[i-1] does not overlap p[i] for any 0 < i < len(p). Notation -------- Write Sum(p) for the (exact) sum of the values in a list p, as a real number. Note that Sum(p) may not be exactly representable as a floating-point number. Lemma ----- Let e be a power of 2. If p is nonoverlapping and nonempty, and 0 < p[-1] < e then 0 < Sum(p) < e. Similarly, if 0 > p[-1] > -e then 0 > Sum(p) > -e. Proof ----- The second part follows from the first by negating. We prove the first part. First inequality: Since p is nonoverlapping, p[i-1] <= p[i]/2 for all 0 < i < len(p). Hence Sum(p) >= p[-1] - 1/2 p[-1] - 1/4 p[-1] - ... > 0. Second inequality: if p[-1] < 2^n then 2^n-p[-1] is larger than, and nonoverlapping with, p[0] through p[-2]. (The least significant bit of 2^n - p[-1] is the same as that of p[-1].) By the first inequality applied to -p[0], ..., -p[-2], 2^n - p[-1], it follows that 0 < -p[0] + -p[1] + ... + -p[-2] + 2^n - p[-1] == 2^n - Sum(p) hence Sum(p) < 2^n, as claimed. The second part of the lemma follows by symmetry. Lemma ----- Suppose that p is nonoverlapping and len(p) >= 2. Suppose furthermore that p[-1] is the correctly rounded value of p[-1] + p[-2]. Then p[-1] is the correctly rounded value of Sum(p) unless: p[-1] + 2*p[-2] is exactly representable, and len(p) >= 3, and p[-3] and p[-2] have the same sign. If all of these conditions are satisfied then the correctly rounded value of Sum(p) is p[-1] + 2*p[-2]. Proof ----- Since, by assumption, p[-2] is nonzero, p[-1] + p[-2] is not exactly representable. Let p[-1] and p[-1]+e be the two floating-point numbers closest to p[-1]+p[-2]. Then abs(e) is a power of 2, and either 0 < p[-2] <= e/2 or 0 > p[-2] >= e/2. 
We give the proof in the former case; the proof in the latter is
virtually identical.  So 0 < p[-2] <= e/2.  Divide into cases:

Case 1: p[-2] < e/2.  Then by the previous lemma,
0 < Sum(p[:-1]) < e/2, so

    p[-1] < Sum(p) < p[-1] + e/2,

hence the sum is closer to p[-1] than p[-1] + e and p[-1] gives the
correctly rounded value.

Case 2: len(p) == 2.  Then by assumption, p[-1] is the correctly
rounded value of p[-1] + p[-2] == Sum(p).

Case 3: p[-2] == e/2, len(p) >= 3, and p[-2] and p[-3] have opposite
signs.  Then -e/2 < p[-3] < 0, so by the preceding lemma,
-e/2 < Sum(p[:-2]) < 0.  Adding p[-1] + p[-2] == p[-1] + e/2 gives:

    p[-1] < Sum(p) < p[-1] + e/2

Hence p[-1] is the closest floating-point number to Sum(p).

Case 4: p[-2] == e/2, len(p) >= 3, and p[-2] and p[-3] have the same
sign.  Then 0 < p[-3] < e/2, so by the preceding lemma,
0 < Sum(p[:-2]) < e/2.  Hence

    p[-1] + e/2 < Sum(p) < p[-1] + e

So p[-1] + e is the closest floating-point number to Sum(p), and the
correctly rounded result is p[-1] + e = p[-1] + 2*p[-2].

Finally, the case where p[-1] is +- the largest representable float
and p[-2] has the same sign as p[-1] needs special treatment.  But if
p[-1] is the largest representable float and p[-2] is positive then,
letting e be the difference between p[-1] and 2**exp_max,

    p[-1] < p[-1] + p[-2] < p[-1] + e/2

where the second inequality is strict, because p[-1]+e/2 would be
rounded up to 2**exp_max (and hence would overflow), contradicting the
assumption that p[-1] + p[-2] rounds to p[-1].  Now as in case 1
above, it follows that Sum(p) rounds to p[-1].


Notes on correctness of msum
==============================

(Assuming the proof of correctness for the original algorithm.)

Divide into two stages:

    Stage 1: accumulating partials
    Stage 2: summing partials

Stage 1: we follow the original algorithm, but whenever an overflow
occurs in computing x+y == hi + lo, we instead compute
x+y-2**1024 = hi+lo.
As before, hi is the correctly rounded value of x+y-2**1024, and
abs(lo) <= min(abs(x), abs(y)).  From these properties it follows as
in the original proof that the elements of partials[1:] are nonzero,
nonadjacent, and increasing in magnitude.  We keep track of the extra
multiples of 2.**1024 in partials[0].  It's theoretically possible for
partials[0] itself to overflow, but only if there are more than 2**53
overflows, which seems exceedingly unlikely in practice.

Lemma
-----
Suppose that x and y are floats with abs(y) <= abs(x).  If x+y
overflows and x is positive then x-2**1024 is exactly representable,
and can be exactly computed as x-2**1023-2**1023.  If x+y overflows
and x is negative then x+2**1024 is exactly representable, and can be
computed as x+2**1023+2**1023.

Proof
-----
If x+y overflows and x is positive then 2**1023 <= x < 2**1024, and it
follows that both x-2**1023 and x-2**1024 are exactly representable.

At the end of stage 1, the value of the sum of the original iterable
is Sum(partials[1:]) + 2**1024*partials[0].

(Special values: if a NaN is encountered the function exits
immediately, returning that NaN.  Otherwise, partials[0] will be nan
if the iterable contained both +inf and -inf, +inf if the iterable
contained +inf but not -inf, and -inf if the iterable contained -inf
but not +inf.)

Stage 2: if partials[0] == 0.0, we just sum partials[1:], using crsum.
If abs(partials[0]) >= 2.0 then the sum overflowed.  The difficult
case is when abs(partials[0]) == 1.0.  For notational simplicity,
assume that partials[0] == 1.0, so that we want to compute
2**1024 + Sum(partials[1:]).  Write p for partials.  Then we have the
following:

Lemma
-----
Assume len(p) > 1.  Assume IEEE 754 format and semantics
(float_info.mant_dig = 53, float_info.max_exp = 1024).
If 2**1024 + p[-1] > 2**1024 - 2**970 then 2**1024 + Sum(p[1:])
overflows.
If 2**1024 + p[-1] < 2**1024 - 2**970 then 2**1024 + Sum(p[1:]) is
representable.
If 2**1024 + p[-1] == 2**1024 - 2**970 then 2**1024 + Sum(p[1:])
rounds to 2**1024-2**971 if len(p) > 2 and p[-2] < 0, and overflows
otherwise.

Proof
-----
Similar to the proof for crsum above.

To detect whether 2**1024 + p[-1] >= 2**1024 - 2**970, we test whether
2.*(2**1023 + p[-1]/2.0) overflows.  (p[-1]/2.0 is exact except
possibly when p[-1] is subnormal, but when p[-1] is subnormal the test
still produces the right result.)

To detect whether 2**1024 + p[-1] == 2**1024 - 2**970, check whether
2**1023 + p[-1]/2 == 2**1023 and 2**1023 + p[-1] - 2**1023 == p[-1].

"""

from time import time
from random import random, gauss, shuffle
from math import isinf, isnan
from sys import float_info

twopow = 2.0**(float_info.max_exp - 1)

def samesign(x, y):
    return (x >= 0.0) == (y >= 0.0)

def twosum(x, y):
    # assumes that abs(x) >= abs(y)
    hi = x + y
    lo = y - (hi - x)
    return hi, lo

def crsum(partials):
    """Compute the sum of a list of nonoverlapping floats.

    On input, partials is a list of nonzero, nonspecial,
    nonoverlapping floats, strictly increasing in magnitude, but
    possibly not all having the same sign.

    On output, the sum of partials gives the error in the returned
    result, which is correctly rounded (using the round-half-to-even
    rule).  The elements of partials remain nonzero, nonspecial,
    nonoverlapping, and increasing in magnitude.

    Assumes IEEE 754 float format and semantics.

    """
    if not partials:
        return 0.0

    # sum from the top, stopping as soon as the sum is inexact.
    total_so_far = partials.pop()
    while partials:
        total_so_far, lo = twosum(total_so_far, partials.pop())
        if lo:
            partials.append(lo)
            break

    # adjust for correct rounding if necessary
    if len(partials) >= 2 and samesign(partials[-1], partials[-2]) and \
            total_so_far + 2*partials[-1] - total_so_far == 2*partials[-1]:
        total_so_far += 2*partials[-1]
        partials[-1] = -partials[-1]

    return total_so_far

def msum(iterable):
    """Full precision sum of values in iterable.

    Returns the value of the sum, rounded to the nearest representable
    floating-point number using the round-half-to-even rule.

    """
    # Stage 1: accumulate partials
    partials = [0.0]
    for x in iterable:
        if isnan(x):
            return x
        elif isinf(x):
            partials[0] += x
        else:
            i = 1
            for y in partials[1:]:
                if abs(x) < abs(y):
                    x, y = y, x
                hi, lo = twosum(x, y)
                if isinf(hi):
                    sign = 1 if hi > 0 else -1
                    x = x - twopow*sign - twopow*sign
                    partials[0] += sign
                    if abs(x) < abs(y):
                        x, y = y, x
                    hi, lo = twosum(x, y)
                if lo:
                    partials[i] = lo
                    i += 1
                x = hi
            partials[i:] = [x] if x else []

    # special cases arising from infinities
    if isinf(partials[0]):
        return partials[0]
    elif isnan(partials[0]):
        raise ValueError('infinities of both signs in summands')

    # Stage 2: sum partials[1:] + 2**exp_max * partials[0]
    if abs(partials[0]) == 1.0 and len(partials) > 1 and \
            not samesign(partials[-1], partials[0]):
        # problem case: decide whether result is representable
        hi, lo = twosum(partials[0]*twopow, partials[-1]/2)
        if isinf(2*hi):
            # overflow, except in edge case...
            if hi+2*lo-hi == 2*lo and \
                    len(partials) > 2 and samesign(lo, partials[-2]):
                return 2*(hi+2*lo)
        else:
            partials[-1:] = [2*lo, 2*hi] if lo else [2*hi]
            partials[0] = 0.0

    if not partials[0]:
        return crsum(partials[1:])

    raise OverflowError('overflow in msum')

def test(func):
    inf = float('inf')
    nan = float('nan')
    test_values = [
        ([], 0.0),
        ([0.0], 0.0),
        ([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100),
        ([1e308, 1e308, -1e308], 1e308),
        ([-1e308, 1e308, 1e308], 1e308),
        ([1e308, -1e308, 1e308], 1e308),
        ([2.0**1023, 2.0**1023, -2.0**1000], 1.7976930277114552e+308),
        ([twopow, twopow, twopow, twopow, -twopow, -twopow, -twopow],
         8.9884656743115795e+307),
        ([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0),
        ([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0),
        ([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0),
        ([2.0**53-4.0, 0.5, 2.0**-54], 2.0**53-3.0),
        ([2.0**1023-2.0**970, -1.0, 2.0**1023], 1.7976931348623157e+308),
        ([float_info.max, float_info.max*2.**-54], float_info.max),
        ([float_info.max, float_info.max*2.**-53], OverflowError),
        ([1./n for n in range(1, 1001)], 7.4854708605503451),
        ([(-1.)**n/n for n in range(1, 1001)], -0.69264743055982025),
        ([1.7**(i+1)-1.7**i for i in range(1000)] + [-1.7**1000], -1.0),
        ([inf, -inf, nan], nan),
        ([nan, inf, -inf], nan),
        ([inf, nan, inf], nan),
        ([inf, inf], inf),
        ([inf, -inf], ValueError),
        ([-inf, 1e308, 1e308, -inf], -inf),
        ([2.0**1023-2.0**970, 0.0, 2.0**1023], OverflowError),
        ([2.0**1023-2.0**970, 1.0, 2.0**1023], OverflowError),
        ([2.0**1023, 2.0**1023], OverflowError),
        ([2.0**1023, 2.0**1023, -1.0], OverflowError),
        ([twopow, twopow, twopow, twopow, -twopow, -twopow], OverflowError),
        ([twopow, twopow, twopow, twopow, -twopow, twopow], OverflowError),
        ([-twopow, -twopow, -twopow, -twopow], OverflowError),
        ([2.**1023, 2.**1023, -2.**971], float_info.max),
        ([2.**1023, 2.**1023, -2.**970], OverflowError),
        ([-2.**970, 2.**1023, 2.**1023, -2.**-1074], float_info.max),
        ([2.**1023, 2.**1023, -2.**970, 2.**-1074], OverflowError),
        ([-2.**1023, 2.**971, -2.**1023], -float_info.max),
        ([-2.**1023, -2.**1023, 2.**970], OverflowError),
        ([-2.**1023, -2.**1023, 2.**970, 2.**-1074], -float_info.max),
        ([-2.**-1074, -2.**1023, -2.**1023, 2.**970], OverflowError),
        ([2.**930, -2.**980, 2.**1023, 2.**1023, twopow, -twopow],
         1.7976931348622137e+308),
        ([2.**1023, 2.**1023, -1e307], 1.6976931348623159e+308),
        ([1e16, 1., 1e-16], 10000000000000002.0),
    ]
    e, t = 0, time()
    for i, (vals, s) in enumerate(test_values):
        if isinstance(s, type) and issubclass(s, Exception):
            try:
                m = func(vals)
            except s:
                continue
            else:
                pass
        else:
            try:
                m = func(vals)
                if m == s or isnan(m) and isnan(s):
                    continue
            except Exception, m:
                pass
        print "test %d failed: got %r, expected %r for %s(%r)" % (
            i, m, s, func.__name__, vals)
        e += 1
    if not e:
        t = (time() - t) * 1e3
        print 'all %s tests passed (%.3f ms)' % (func.__name__, t)
    return e

def test2(math_sum, cmath_sum=None):
    # Compare C math.- and cmath.sum results with those from the msum()
    # function above.  Tests copied from the original Python recipe at
    #
    e, t = 0, time()
    for j in xrange(1000):
        vals = [7, 1e100, -7, -1e100, -9e-20, 8e-20] * 10
        s = 0
        for i in range(200):
            v = gauss(0, random()) ** 7 - s
            s += v
            vals.append(v)
        shuffle(vals)
        s = msum(vals)
        try:
            m = math_sum(vals)
        except Exception, m:
            pass
        if m != s:
            print 'test failed: got %r, expected %r for %.100r' % (
                m, s, vals)
            e += 1
        if cmath_sum:  # check complex too
            c = complex(s, -s)
            try:
                m = cmath_sum([complex(v, -v) for v in vals])
            except Exception, m:
                pass
            if m != c:
                print 'test failed: got %r, expected %r for %.100r' % (
                    m, c, vals)
                e += 1
    if not e:
        t = (time() - t) * 1e3
        if cmath_sum:
            c = ', incl complex'
        else:
            c = ''
        print 'all %s tests passed%s (%.3f ms)' % ('Compare', c, t)
    return e

if __name__ == '__main__':
    t = time()
    e = test(msum)
    try:
        from math import sum as math_sum
        from cmath import sum as cmath_sum
        e += test(math_sum)
        e += test2(math_sum, cmath_sum)
    except ImportError:
        pass
    t = (time() - t) * 1e3
    if e:
        print e, 'tests failed (%.3f ms)' % t
    else:
        print 'all tests passed (%.3f ms)' % t
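# A possible cross-check for msum on finite inputs (a sketch added for
# illustration; exact_sum is a hypothetical helper, not part of the
# original tests).  It sums the inputs exactly as rationals and rounds
# once at the end.  This assumes CPython, where converting a Fraction
# to float is correctly rounded (round-half-to-even) and raises
# OverflowError when the rounded value is too large for a float.
from fractions import Fraction

def exact_sum(values):
    """Correctly rounded sum of finite floats, via exact rationals.

    Infinities and NaNs are not supported: Fraction rejects them.

    """
    total = Fraction(0)
    for v in values:
        total += Fraction(v)   # every finite float is an exact rational
    return float(total)        # single rounding at the very end

# For any list vals of finite floats, msum(vals) and exact_sum(vals)
# should then either agree or both raise OverflowError.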