"""
Proof of correctness for crsum
==============================
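
Terminology
-----------

Call a list p of floats *nonoverlapping* if all its elements are
nonzero and nonspecial (that is, finite), if the elements are
increasing in magnitude, and if p[i-1] does not overlap p[i] for every
0 < i < len(p): the most significant nonzero bit of p[i-1] must lie
strictly below the least significant nonzero bit of p[i].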
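
The definition can be checked mechanically. The sketch below is illustrative only and is not used by crsum or msum; the helper names msb_exp, lsb_exp and is_nonoverlapping are hypothetical, and it assumes binary floating point, so that float.as_integer_ratio returns a power-of-two denominator.

```python
import math

def msb_exp(x):
    # Exponent of the most significant nonzero bit of a nonzero finite float.
    # math.frexp gives x == m * 2**e with 0.5 <= abs(m) < 1, so that bit is 2**(e-1).
    return math.frexp(x)[1] - 1

def lsb_exp(x):
    # Exponent of the least significant nonzero bit of a nonzero finite float.
    num, den = abs(x).as_integer_ratio()   # den is a power of two for binary floats
    low_bit = num & -num                   # isolate the lowest set bit of num
    return (low_bit.bit_length() - 1) - (den.bit_length() - 1)

def is_nonoverlapping(p):
    # Every element must be nonzero and finite.
    if any(x == 0.0 or math.isinf(x) or math.isnan(x) for x in p):
        return False
    for a, b in zip(p, p[1:]):
        # Increasing magnitude, and the top bit of a strictly below the
        # bottom bit of b.
        if not (abs(a) < abs(b) and msb_exp(a) < lsb_exp(b)):
            return False
    return True

print(is_nonoverlapping([0.75, 1.0]))   # True: 0.11b lies entirely below 1b
print(is_nonoverlapping([0.5, 1.5]))    # False: both contain the 2**-1 bit
```

Note that [0.75, 1.0] qualifies even though 0.75 is more than half of 1.0: the definition constrains bit positions, not ratios.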

Notation
--------

Write Sum(p) for the (exact) sum of the values in a list p, as a real
number. Note that Sum(p) may not be exactly representable as a
floating-point number.
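
One way to model Sum(p) exactly in Python is with fractions.Fraction, which converts a finite float to a rational number without rounding. A minimal sketch; exact_sum is a hypothetical helper, not part of this module:

```python
from fractions import Fraction

def exact_sum(p):
    # Sum(p): Fraction(x) converts each finite float exactly, so the
    # rational total carries no rounding error.
    return sum((Fraction(x) for x in p), Fraction(0))

p = [1.0, 2.0**-60]
s = exact_sum(p)
print(s == 1 + Fraction(1, 2**60))   # True
print(float(s))                      # 1.0: the nearest float is not Sum(p) itself
```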

Lemma
-----

Let e be a power of 2. If p is nonoverlapping and nonempty, and

    0 < p[-1] < e

then

    0 < Sum(p) < e.

Similarly, if 0 > p[-1] > -e then 0 > Sum(p) > -e.
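
Proof
-----

The second part follows from the first by negating every element of
p, so we prove the first part.

First inequality: for a nonzero float x, write msb(x) and lsb(x) for
the values of its most and least significant nonzero bits, so that
abs(x) <= 2*msb(x) - lsb(x). Since p is nonoverlapping,
2*msb(p[i-1]) <= lsb(p[i]) for all 0 < i < len(p). Summing the bounds
for p[0], ..., p[-2], the terms telescope:

    abs(Sum(p[:-1])) <= 2*msb(p[-2]) - lsb(p[0]) < lsb(p[-1]) <= p[-1].

(If len(p) == 1 then Sum(p[:-1]) == 0 and this holds trivially.) Hence
Sum(p) == p[-1] + Sum(p[:-1]) > 0.

Second inequality: since e is a power of 2 and 0 < p[-1] < e, the
number e - p[-1] is positive and has the same least significant
nonzero bit as p[-1]. It is therefore larger in magnitude than, and
nonoverlapping with, each of p[0] through p[-2]. The argument for the
first inequality uses only binary expansions, so it applies to the
list -p[0], ..., -p[-2], e - p[-1] even if e - p[-1] is not a float,
and gives

    0 < -p[0] - p[1] - ... - p[-2] + e - p[-1] == e - Sum(p),

hence Sum(p) < e, as claimed.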
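
A randomized spot check of the lemma using exact rational arithmetic. This is a sketch under stated assumptions, not a substitute for the proof: the helpers expansion and lemma_holds are hypothetical, and the nonoverlapping test lists are generated by repeatedly peeling the nearest float off a random positive rational.

```python
import math
import random
from fractions import Fraction

def expansion(r, max_terms=6):
    # Hypothetical test-data generator: peel floats off a rational r by
    # repeated rounding.  Each remainder is at most half an ulp of the float
    # just removed, so the terms are nonoverlapping.  Returned smallest
    # magnitude first, matching the lemma's indexing.
    terms = []
    while r and len(terms) < max_terms:
        x = float(r)
        if x == 0.0:          # remainder underflowed; stop
            break
        terms.append(x)
        r -= Fraction(x)
    return terms[::-1]

def lemma_holds(p):
    # Take e as the smallest power of 2 above p[-1] > 0 and test
    # 0 < Sum(p) < e exactly.
    e = Fraction(2) ** math.frexp(p[-1])[1]
    s = sum((Fraction(x) for x in p), Fraction(0))
    return 0 < s < e

random.seed(1)
for _ in range(1000):
    r = Fraction(random.randrange(1, 2**200), random.randrange(1, 2**100))
    assert lemma_holds(expansion(r))
```

Passing random trials is evidence rather than proof; the smallest admissible e is used because it gives the tightest form of the bound.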

Lemma
-----

Suppose that p is nonoverlapping and len(p) >= 2. Suppose furthermore
that p[-1] is the correctly rounded value of p[-1] + p[-2]. Then
p[-1] is the correctly rounded value of Sum(p) unless all three of the
following conditions hold:

    p[-1] + 2*p[-2] is exactly representable,
    len(p) >= 3, and
    p[-3] and p[-2] have the same sign.

If all of these conditions are satisfied then the correctly rounded
value of Sum(p) is p[-1] + 2*p[-2].
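
A concrete instance of the exceptional case, using the same values as one of the test cases below (listed here smallest element first). The sketch relies on float(Fraction) being correctly rounded, which holds in CPython because it divides two ints and int true division rounds correctly; correctly_rounded is a hypothetical helper.

```python
from fractions import Fraction

def correctly_rounded(p):
    # Round the exact rational Sum(p) to the nearest double; float(Fraction)
    # divides numerator by denominator, which CPython rounds correctly.
    return float(sum((Fraction(x) for x in p), Fraction(0)))

# Nonoverlapping, smallest magnitude first, as in the lemma.
p = [2.0**-100, 1.0, 2.0**53]
top = p[-1] + p[-2]       # 2**53 + 1 is a tie; round-half-to-even gives 2**53
print(top == p[-1])                               # True
print(correctly_rounded(p) == p[-1] + 2*p[-2])    # True: p[-3] > 0 breaks the tie upwards

# With p[-3] negative instead, p[-1] itself is the correctly rounded sum.
q = [-2.0**-100, 1.0, 2.0**53]
print(correctly_rounded(q) == q[-1])              # True
```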

Proof
-----

Since p[-2] is nonzero by assumption, p[-1] + p[-2] differs from
p[-1]; as p[-1] is its correctly rounded value, p[-1] + p[-2] is not
exactly representable. Let p[-1] and p[-1] + e be the two
floating-point numbers closest to p[-1] + p[-2]. Then abs(e) is a
power of 2, and either 0 < p[-2] <= e/2 or 0 > p[-2] >= e/2. We give
the proof in the former case; the proof in the latter is virtually
identical.

So 0 < p[-2] <= e/2. Divide into cases:

Case 1: p[-2] < e/2. Then by the previous lemma, 0 < Sum(p[:-1]) <
e/2, so p[-1] < Sum(p) < p[-1] + e/2. Hence Sum(p) is closer to p[-1]
than to p[-1] + e, and p[-1] is the correctly rounded value.

Case 2: len(p) == 2. Then by assumption, p[-1] is the correctly
rounded value of p[-1] + p[-2] == Sum(p).

Case 3: p[-2] == e/2, len(p) >= 3, and p[-2] and p[-3] have opposite
signs. Since abs(p[-3]) < abs(p[-2]), we have -e/2 < p[-3] < 0, so by
the preceding lemma,

    -e/2 < Sum(p[:-2]) < 0.

Adding p[-1] + p[-2] == p[-1] + e/2 gives

    p[-1] < Sum(p) < p[-1] + e/2.

Hence p[-1] is the closest floating-point number to Sum(p).

Case 4: p[-2] == e/2, len(p) >= 3, and p[-2] and p[-3] have the same
sign. Then 0 < p[-3] < e/2, so by the preceding lemma,
0 < Sum(p[:-2]) < e/2. Hence

    p[-1] + e/2 < Sum(p) < p[-1] + e.

So p[-1] + e is the closest floating-point number to Sum(p), and the
correctly rounded result is p[-1] + e == p[-1] + 2*p[-2].

Finally, the case where p[-1] is +- the largest representable float
and p[-2] has the same sign as p[-1] needs special treatment, because
then p[-1] + e is not a finite float. Suppose that p[-1] is the
largest representable float and p[-2] is positive, and let e be the
difference between 2**exp_max and p[-1], where exp_max is
float_info.max_exp. Then

    p[-1] < p[-1] + p[-2] < p[-1] + e/2,

where the second inequality is strict because p[-1] + e/2 would be
rounded up to 2**exp_max (ties go to even, and the largest float has
an odd significand) and hence would overflow, contradicting the
assumption that p[-1] + p[-2] rounds to p[-1]. Now, as in case 1
above, it follows that Sum(p) rounds to p[-1].
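
The boundary in this special case is easy to observe directly. A short sketch assuming IEEE 754 binary64 doubles, so that sys.float_info.max == 2**1024 - 2**971 and e == 2**971:

```python
import sys

big = sys.float_info.max      # 2**1024 - 2**971 for IEEE 754 binary64
e = 2.0**971                  # gap between big and 2**1024 (2**exp_max)

print(big + e/2)              # inf: the halfway point rounds up (ties to even) and overflows
print(big + e/4 == big)       # True: anything below e/2 rounds back to big

# The same boundary, phrased as in msum's test cases:
print(big + big*2.0**-54 == big)   # True: big*2**-54 is just below e/2
print(big + big*2.0**-53)          # inf: big*2**-53 is just above e/2
```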

Notes on correctness of msum
============================

(Assuming the proof of correctness for the original algorithm.)

Divide into two stages:

    Stage 1: accumulating partials
    Stage 2: summing partials

Stage 1: we follow the original algorithm, but whenever an overflow
occurs in computing x+y == hi+lo, we instead compute
x+y-2**1024 == hi+lo (or x+y+2**1024 == hi+lo when the overflow is
towards -inf). As before, hi is the correctly rounded value of
x+y-2**1024, and abs(lo) <= min(abs(x), abs(y)). From these
properties it follows, as in the original proof, that the elements
of partials[1:] are nonzero, nonadjacent (and in particular
nonoverlapping), and increasing in magnitude.

We keep track of the extra multiples of 2**1024 in partials[0].
It is theoretically possible for partials[0] itself to stop counting
exactly (a float counter loses integer precision beyond 2**53), but
only after more than 2**53 overflows, which seems exceedingly unlikely
in practice.
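
The overflow-avoiding step of stage 1 can be isolated as follows. This is a sketch that mirrors the inner loop of msum below rather than a separate API; safe_twosum is a hypothetical name, and the code assumes IEEE 754 binary64 with max_exp == 1024.

```python
import math
from fractions import Fraction

TWOPOW = 2.0**1023   # 2**(max_exp - 1), assuming IEEE 754 binary64

def twosum(x, y):
    # Fast two-sum: hi + lo == x + y exactly, assuming abs(x) >= abs(y).
    hi = x + y
    lo = y - (hi - x)
    return hi, lo

def safe_twosum(x, y):
    # Returns (k, hi, lo) with x + y == k*2**1024 + hi + lo exactly,
    # where k is -1, 0 or 1.  Hypothetical helper mirroring msum's stage 1.
    if abs(x) < abs(y):
        x, y = y, x
    hi, lo = twosum(x, y)
    if not math.isinf(hi):
        return 0, hi, lo
    k = 1 if hi > 0 else -1
    x = x - TWOPOW*k - TWOPOW*k     # exact, by the lemma that follows
    if abs(x) < abs(y):
        x, y = y, x
    hi, lo = twosum(x, y)
    return k, hi, lo

print(safe_twosum(2.0**1023, 2.0**1023))   # (1, 0.0, 0.0)
```

The returned k is what msum adds to partials[0]; hi and lo play the same role as in the unmodified algorithm. Exactness can be confirmed with Fraction, for example by checking k * 2**1024 + Fraction(hi) + Fraction(lo) against Fraction(x) + Fraction(y).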
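
Lemma
-----

Suppose that x and y are floats with abs(y) <= abs(x). If x+y
overflows and x is positive then x-2**1024 is exactly representable,
and can be computed exactly as x-2**1023-2**1023. If x+y overflows and
x is negative then x+2**1024 is exactly representable, and can be
computed exactly as x+2**1023+2**1023.

Proof
-----

Suppose x+y overflows and x is positive. Since abs(y) <= x we have
x+y <= 2*x, and rounding is monotonic, so the rounded sum is at most
2*x; this is finite whenever x < 2**1023. Hence
2**1023 <= x < 2**1024. Every float in that range is a multiple of
2**971, so x-2**1023 and x-2**1024 are multiples of 2**971 of
magnitude at most 2**1023, and both are therefore exactly
representable. The negative case follows by symmetry.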
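
The exactness claims can be spot-checked by comparing floating-point results against exact rationals. A sketch assuming binary64 doubles; it samples doubles in [2**1023, 2**1024) uniformly by building each one from a 53-bit integer significand with math.ldexp.

```python
import math
import random
from fractions import Fraction

TWOPOW = 2.0**1023

random.seed(0)
for _ in range(2000):
    # A random double in [2**1023, 2**1024): a 53-bit integer significand
    # scaled by 2**971 (math.ldexp is exact here).
    x = math.ldexp(random.randrange(2**52, 2**53), 971)
    assert Fraction(x - TWOPOW) == Fraction(x) - 2**1023             # first step exact
    assert Fraction(x - TWOPOW - TWOPOW) == Fraction(x) - 2**1024    # second step exact
    assert Fraction(-x + TWOPOW + TWOPOW) == Fraction(-x) + 2**1024  # negative case
print('all subtractions were exact')
```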

At the end of stage 1, the value of the sum of the original
iterable is Sum(partials[1:]) + 2**1024*partials[0].

(Special values: if a NaN is encountered the function exits
immediately, returning that NaN. Otherwise, partials[0] will be
nan if the iterable contained both +inf and -inf, +inf if the
iterable contained +inf but not -inf, and -inf if the iterable
contained -inf but not +inf.)
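
The rule for infinities falls out of ordinary float addition: inf + inf == inf, while inf + (-inf) is nan, and nan absorbs everything added after it. A simplified sketch that tracks only the special summands (msum also adds overflow counts to partials[0]); classify_infinities is a hypothetical name:

```python
import math

def classify_infinities(values):
    # Simplified model of how partials[0] treats special values: a NaN
    # summand is returned at once, infinite summands are accumulated, and
    # finite summands are ignored here.
    acc = 0.0
    for x in values:
        if math.isnan(x):
            return x
        if math.isinf(x):
            acc += x
    return acc

inf = float('inf')
print(classify_infinities([inf, 1.0, inf]))   # inf
print(classify_infinities([-inf, 2.0]))       # -inf
print(classify_infinities([inf, -inf]))       # nan, since inf + -inf is nan
```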

Stage 2: if partials[0] == 0.0, we just sum partials[1:] using crsum.
If abs(partials[0]) >= 2.0 then the sum overflows, since by the first
lemma abs(Sum(partials[1:])) < 2**1024. The difficult case is when
abs(partials[0]) == 1.0.

For notational simplicity, assume that partials[0] == 1.0, so that
we want to compute 2**1024 + Sum(partials[1:]). Write p for
partials. Then we have the following:

Lemma
-----

Assume len(p) > 1, and assume IEEE 754 format and semantics
(float_info.mant_dig == 53, float_info.max_exp == 1024). Note that
2**1024 - 2**970 is exactly halfway between the largest finite float,
2**1024 - 2**971, and 2**1024.

    If 2**1024 + p[-1] > 2**1024 - 2**970 then 2**1024 + Sum(p[1:])
    overflows.

    If 2**1024 + p[-1] < 2**1024 - 2**970 then 2**1024 + Sum(p[1:])
    rounds to a finite float.

    If 2**1024 + p[-1] == 2**1024 - 2**970 then 2**1024 + Sum(p[1:])
    rounds to 2**1024 - 2**971 if len(p) > 2 and p[-2] < 0, and
    overflows otherwise.
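
The three cases can be illustrated with exact integers standing in for 2**1024 + Sum(p[1:]). In CPython, converting an int to float rounds correctly (ties to even) and raises OverflowError if the rounded value would exceed the largest double; the sketch relies on that behaviour and uses a hypothetical helper to_float.

```python
import sys

def to_float(n):
    # Correctly rounded int -> float conversion (CPython); None on overflow.
    try:
        return float(n)
    except OverflowError:
        return None

M = sys.float_info.max        # 2**1024 - 2**971 for binary64
B = 2**1024 - 2**970          # the boundary value in the lemma

print(to_float(2**1024 - 2**971) == M)   # True
print(to_float(B - 1) == M)              # True: below the midpoint, rounds down
print(to_float(B))                       # None: the exact tie rounds up to 2**1024
print(to_float(B + 1))                   # None: above the midpoint
```

Here B - 1 and B + 1 play the roles of a negative and a positive p[-2] in the third case of the lemma.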

Proof
-----

Similar to the proof for crsum above.

To detect whether 2**1024 + p[-1] >= 2**1024 - 2**970, we test
whether 2.0*(2**1023 + p[-1]/2.0) overflows when evaluated in
floating-point arithmetic. (p[-1]/2.0 is exact except possibly when
p[-1] is subnormal, but when p[-1] is subnormal the test still
produces the right result.)

To detect whether 2**1024 + p[-1] == 2**1024 - 2**970, check whether,
again in floating-point arithmetic, 2**1023 + p[-1]/2 == 2**1023 and
2**1023 + p[-1] - 2**1023 == p[-1].
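
The two float-only tests above can be packaged as one classifier. A sketch assuming binary64 and p[-1] < 0 (the case partials[0] == 1.0); classify_top is a hypothetical name, and msum below performs equivalent checks inline through twosum rather than calling such a helper.

```python
import math

TWOPOW = 2.0**1023   # 2**(max_exp - 1), assuming IEEE 754 binary64

def classify_top(p_last):
    # Hypothetical helper: for p_last < 0, compare 2**1024 + p_last with
    # the boundary 2**1024 - 2**970 using only float operations.
    half = TWOPOW + p_last / 2.0          # rounded 2**1023 + p_last/2
    if not math.isinf(2.0 * half):
        return 'below'                    # 2**1024 + p_last < 2**1024 - 2**970
    if half == TWOPOW and TWOPOW + p_last - TWOPOW == p_last:
        return 'boundary'                 # exactly 2**1024 - 2**970
    return 'above'                        # 2**1024 + p_last > 2**1024 - 2**970

print(classify_top(-2.0**971))   # below
print(classify_top(-2.0**970))   # boundary
print(classify_top(-2.0**969))   # above
```

In the boundary case, 2**1023 - 2**969 is itself a tie that rounds to 2**1023, which is why the second check (exactness of 2**1023 + p[-1]) is needed to separate -2**970 from slightly smaller magnitudes.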
"""
from time import time
from random import random, gauss, shuffle
from math import isinf, isnan
from sys import float_info
twopow = 2.0**(float_info.max_exp - 1)
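

def samesign(x, y):
    # True when x and y are both >= 0.0 or both < 0.0 (-0.0 counts as
    # nonnegative).
    return (x >= 0.0) == (y >= 0.0)


def twosum(x, y):
    # Error-free transformation (Dekker's Fast2Sum): hi is the rounded
    # value of x + y and lo the rounding error, so hi + lo == x + y
    # exactly, provided abs(x) >= abs(y) and x + y does not overflow.
    hi = x + y
    lo = y - (hi - x)
    return hi, lo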


def crsum(partials):
    """Compute the sum of a list of nonoverlapping floats.

    On input, partials is a list of nonzero, nonspecial,
    nonoverlapping floats, strictly increasing in magnitude, but
    possibly not all having the same sign.

    On output, the sum of partials gives the error in the returned
    result, which is correctly rounded (using the round-half-to-even
    rule). The elements of partials remain nonzero, nonspecial,
    nonoverlapping, and increasing in magnitude.

    Assumes IEEE 754 float format and semantics.
    """
    if not partials:
        return 0.0

    # sum from the top, stopping as soon as the sum is inexact.
    total_so_far = partials.pop()
    while partials:
        total_so_far, lo = twosum(total_so_far, partials.pop())
        if lo:
            partials.append(lo)
            break

    # Adjust for correct rounding if necessary (second lemma in the module
    # docstring): if the two largest remaining partials share a sign and
    # total_so_far + 2*partials[-1] is exactly representable, that value is
    # the correctly rounded sum, and the remaining error negates partials[-1].
    if (len(partials) >= 2 and samesign(partials[-1], partials[-2]) and
            total_so_far + 2*partials[-1] - total_so_far == 2*partials[-1]):
        total_so_far += 2*partials[-1]
        partials[-1] = -partials[-1]
    return total_so_far


def msum(iterable):
    """Full precision sum of values in iterable. Returns the value of the
    sum, rounded to the nearest representable floating-point number
    using the round-half-to-even rule.
    """
    # Stage 1: accumulate partials
    partials = [0.0]
    for x in iterable:
        if isnan(x):
            return x
        elif isinf(x):
            partials[0] += x
        else:
            i = 1
            for y in partials[1:]:
                if abs(x) < abs(y):
                    x, y = y, x
                hi, lo = twosum(x, y)
                if isinf(hi):
                    sign = 1 if hi > 0 else -1
                    x = x - twopow*sign - twopow*sign
                    partials[0] += sign
                    if abs(x) < abs(y):
                        x, y = y, x
                    hi, lo = twosum(x, y)
                if lo:
                    partials[i] = lo
                    i += 1
                x = hi
            partials[i:] = [x] if x else []

    # special cases arising from infinities
    if isinf(partials[0]):
        return partials[0]
    elif isnan(partials[0]):
        raise ValueError('infinities of both signs in summands')

    # Stage 2: sum partials[1:] + 2**exp_max * partials[0]
    if (abs(partials[0]) == 1.0 and len(partials) > 1 and
            not samesign(partials[-1], partials[0])):
        # problem case: decide whether result is representable
        hi, lo = twosum(partials[0]*twopow, partials[-1]/2)
        if isinf(2*hi):
            # overflow, except in edge case...
            if (hi+2*lo-hi == 2*lo and
                    len(partials) > 2 and samesign(lo, partials[-2])):
                return 2*(hi+2*lo)
        else:
            partials[-1:] = [2*lo, 2*hi] if lo else [2*hi]
            partials[0] = 0.0
    if not partials[0]:
        return crsum(partials[1:])
    raise OverflowError('overflow in msum')


def test(func):
    inf = float('inf')
    nan = float('nan')
    test_values = [
        ([], 0.0),
        ([0.0], 0.0),
        ([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100),
        ([1e308, 1e308, -1e308], 1e308),
        ([-1e308, 1e308, 1e308], 1e308),
        ([1e308, -1e308, 1e308], 1e308),
        ([2.0**1023, 2.0**1023, -2.0**1000], 1.7976930277114552e+308),
        ([twopow, twopow, twopow, twopow, -twopow, -twopow, -twopow], 8.9884656743115795e+307),
        ([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0),
        ([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0),
        ([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0),
        ([2.0**53-4.0, 0.5, 2.0**-54], 2.0**53-3.0),
        ([2.0**1023-2.0**970, -1.0, 2.0**1023], 1.7976931348623157e+308),
        ([float_info.max, float_info.max*2.**-54], float_info.max),
        ([float_info.max, float_info.max*2.**-53], OverflowError),
        ([1./n for n in range(1, 1001)], 7.4854708605503451),
        ([(-1.)**n/n for n in range(1, 1001)], -0.69264743055982025),
        ([1.7**(i+1)-1.7**i for i in range(1000)] + [-1.7**1000], -1.0),
        ([inf, -inf, nan], nan),
        ([nan, inf, -inf], nan),
        ([inf, nan, inf], nan),
        ([inf, inf], inf),
        ([inf, -inf], ValueError),
        ([-inf, 1e308, 1e308, -inf], -inf),
        ([2.0**1023-2.0**970, 0.0, 2.0**1023], OverflowError),
        ([2.0**1023-2.0**970, 1.0, 2.0**1023], OverflowError),
        ([2.0**1023, 2.0**1023], OverflowError),
        ([2.0**1023, 2.0**1023, -1.0], OverflowError),
        ([twopow, twopow, twopow, twopow, -twopow, -twopow], OverflowError),
        ([twopow, twopow, twopow, twopow, -twopow, twopow], OverflowError),
        ([-twopow, -twopow, -twopow, -twopow], OverflowError),
        ([2.**1023, 2.**1023, -2.**971], float_info.max),
        ([2.**1023, 2.**1023, -2.**970], OverflowError),
        ([-2.**970, 2.**1023, 2.**1023, -2.**-1074], float_info.max),
        ([2.**1023, 2.**1023, -2.**970, 2.**-1074], OverflowError),
        ([-2.**1023, 2.**971, -2.**1023], -float_info.max),
        ([-2.**1023, -2.**1023, 2.**970], OverflowError),
        ([-2.**1023, -2.**1023, 2.**970, 2.**-1074], -float_info.max),
        ([-2.**-1074, -2.**1023, -2.**1023, 2.**970], OverflowError),
        ([2.**930, -2.**980, 2.**1023, 2.**1023, twopow, -twopow], 1.7976931348622137e+308),
        ([2.**1023, 2.**1023, -1e307], 1.6976931348623159e+308),
        ([1e16, 1., 1e-16], 10000000000000002.0),
    ]
    e, t = 0, time()
    for i, (vals, s) in enumerate(test_values):
        if isinstance(s, type) and issubclass(s, Exception):
            # an exception of type s is expected
            try:
                m = func(vals)
            except s:
                continue
            except Exception as exc:  # wrong exception type
                m = exc
        else:
            try:
                m = func(vals)
                if m == s or isnan(m) and isnan(s):
                    continue
            except Exception as exc:
                m = exc
        print('test %d failed: got %r, expected %r for %s(%r)' % (
            i, m, s, func.__name__, vals))
        e += 1
    if not e:
        t = (time() - t) * 1e3
        print('all %s tests passed (%.3f ms)' % (func.__name__, t))
    return e


def test2(math_sum, cmath_sum=None):
    # Compare C math.sum and cmath.sum results with those from the msum()
    # function above. Tests copied from the original Python recipe at
    #
    e, t = 0, time()
    for j in range(1000):
        vals = [7, 1e100, -7, -1e100, -9e-20, 8e-20] * 10
        s = 0
        for i in range(200):
            v = gauss(0, random()) ** 7 - s
            s += v
            vals.append(v)
        shuffle(vals)
        s = msum(vals)
        try:
            m = math_sum(vals)
        except Exception as exc:
            m = exc
        if m != s:
            print('test failed: got %r, expected %r for %.100r' % (m, s, vals))
            e += 1
        if cmath_sum:  # check complex too
            c = complex(s, -s)
            try:
                m = cmath_sum([complex(v, -v) for v in vals])
            except Exception as exc:
                m = exc
            if m != c:
                print('test failed: got %r, expected %r for %.100r' % (m, c, vals))
                e += 1
    if not e:
        t = (time() - t) * 1e3
        if cmath_sum:
            c = ', incl complex'
        else:
            c = ''
        print('all %s tests passed%s (%.3f ms)' % ('Compare', c, t))
    return e
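

if __name__ == '__main__':
    t = time()
    e = test(msum)
    try:
        # C implementations to compare against, if available; released
        # CPython versions provide math.fsum rather than math.sum, in which
        # case the ImportError below skips these comparisons.
        from math import sum as math_sum
        from cmath import sum as cmath_sum
        e += test(math_sum)
        e += test2(math_sum, cmath_sum)
    except ImportError:
        pass
    t = (time() - t) * 1e3
    if e:
        print('%d tests failed (%.3f ms)' % (e, t))
    else:
        print('all tests passed (%.3f ms)' % t)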