
classification
Title: Assertion failure when calling statistics.variance() on a float32 Numpy array
Type: behavior
Stage: resolved
Components: Library (Lib)
Versions: Python 3.11, Python 3.10, Python 3.9

process
Status: closed
Resolution: fixed
Dependencies:
Superseder:
Assigned To:
Nosy List: iritkatriel, mark.dickinson, reed, rhettinger, steven.daprano, xtreak
Priority: normal
Keywords: patch

Created on 2020-01-05 05:34 by reed, last changed 2022-04-11 14:59 by admin. This issue is now closed.

Pull Requests
PR 27960 (status: merged; linked by rhettinger, 2021-08-26 01:20)
Messages (12)
msg359323 - Author: Reed (reed) Date: 2020-01-05 05:34
If a float32 Numpy array is passed to statistics.variance(), an assertion failure occurs. For example:

    import statistics
    import numpy as np
    x = np.array([1, 2], dtype=np.float32)
    statistics.variance(x)

The assertion error is:

    assert T == U and count == count2

Even if you convert x to a list with `x = list(x)`, the issue still occurs. The issue is caused by the following lines in statistics.py (https://github.com/python/cpython/blob/ec007cb43faf5f33d06efbc28152c7fdcb2edb9c/Lib/statistics.py#L687-L691):

    T, total, count = _sum((x-c)**2 for x in data)
    # The following sum should mathematically equal zero, but due to rounding
    # error may not.
    U, total2, count2 = _sum((x-c) for x in data)
    assert T == U and count == count2

When a float32 Numpy value is squared in the term (x-c)**2, it turns into a float64 value, so the two sums see different types and the `T == U` assertion fails. I think the best way to fix this would be to replace (x-c)**2 with (x-c)*(x-c). This fix would no longer assume that the input's ** operator returns the same type.
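The type mismatch can be reproduced without NumPy. The sketch below is a minimal illustration that assumes a hypothetical float subclass, `Float32Like`, mimicking the reported behaviour: subtraction and multiplication keep the subclass, while `**` with a Python int returns a plain float (standing in for float64). It only inspects the types the two sums would see; it does not call statistics.variance().

```python
class Float32Like(float):
    """Hypothetical stand-in for a NumPy float32 scalar (illustration only)."""

    def __sub__(self, other):
        # Same-type arithmetic keeps the "float32" type.
        return Float32Like(float(self) - float(other))

    def __mul__(self, other):
        return Float32Like(float(self) * float(other))

    def __pow__(self, other):
        # Mixed-type operation: the result is "promoted" to a plain float,
        # mimicking float32 ** int -> float64 in the report.
        return float(self) ** other


data = [Float32Like(1), Float32Like(2)]
c = Float32Like(1.5)

# Types produced by the terms of each sum in the original code.
squared_types = {type((x - c) ** 2) for x in data}       # {float}
deviation_types = {type(x - c) for x in data}            # {Float32Like}
# The proposed fix: multiplying the deviation by itself keeps the type.
product_types = {type((x - c) * (x - c)) for x in data}  # {Float32Like}

print(squared_types == deviation_types)  # False: the T == U check fails
print(product_types == deviation_types)  # True
```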
msg359324 - Author: Karthikeyan Singaravelan (xtreak) (Python committer) Date: 2020-01-05 07:56
I think this is more an implementation artifact of how NumPy defines operations between float32 and float64, and it could break again if (x-c) * (x-c) were also changed to return float64 in the future.
msg359325 - Author: Steven D'Aprano (steven.daprano) (Python committer) Date: 2020-01-05 08:03
Nice analysis and bug report, thank you! That's pretty strange behaviour for float32, but I guess we're stuck with it.

I wonder whether the type assertion has outlived its usefulness. That is, we could drop the `T == U` part and change the assertion to `assert count == count2` only.

If we removed the failing part of the assertion, and changed the final line to `return (U, total)`, that ought to keep the exact sum but convert to float32 later, rather than float64.

I am inclined to have the stdev of float32 return a float32 if possible. What do you think?

We should check the numpy docs to see what the conversion rules for numpy floats are.
msg359329 - Author: Mark Dickinson (mark.dickinson) (Python committer) Date: 2020-01-05 10:56
[Karthikeyan]

> can possibly break again if (x-c) * (x-c) was also changed to return float64 in future

I think it's safe to assume that multiplying two NumPy float32s will continue to give a float32 back in the future; NumPy has no reason to give back a different type, and changing this would be a big breaking change.

The big difference between (x-c)**2 and (x-c)*(x-c) in this respect is that the latter is purely an operation on float32 operands, while the former is a mixed-type operation: NumPy sees a binary operation between a float32 and a Python int, and has to figure out a suitable type for the result. The int to float32 conversion is regarded as introducing unacceptable precision loss, so it chooses float64 as an acceptable output type, converts both input operands to that type, and then does the computation.

Some relevant parts of the NumPy docs:

https://docs.scipy.org/doc/numpy/reference/ufuncs.html#casting-rules
https://docs.scipy.org/doc/numpy/reference/generated/numpy.result_type.html

Even for pure Python floats, x*x is a simpler, more accurate, and likely faster operation than x**2. On a typical system, the former (eventually) maps to a hardware floating-point multiplication, which is highly likely to be correctly rounded, while the latter, after conversion of the r.h.s. to a float, maps to a libm pow call. That libm pow call *could* conceivably have a fast/accurate path for a right-hand-side of 2.0, but it could equally conceivably not have such a path.
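One way to check the "correctly rounded" claim for x*x on a given machine is to compare it with the exact square computed with fractions.Fraction and rounded once. This sketch relies on the assumption that float() of a Fraction rounds correctly in CPython (it uses integer true division, which is correctly rounded); the helper name is hypothetical, and the pow() comparison is only reported, since its result is platform-dependent.

```python
import random
from fractions import Fraction


def square_is_correctly_rounded(x: float) -> bool:
    """Hypothetical helper: does x*x equal the exact square rounded once?"""
    exact = Fraction(x) ** 2      # exact rational value of x squared
    return x * x == float(exact)  # float(Fraction) performs a single rounding


rng = random.Random(12345)
samples = [rng.uniform(-1e6, 1e6) for _ in range(10_000)]

mul_ok = all(square_is_correctly_rounded(x) for x in samples)
# x**2 goes through the libm pow() call; count any disagreements with x*x.
pow_mismatches = sum(1 for x in samples if x ** 2 != x * x)
print(mul_ok, pow_mismatches)
```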

OTOH, (x-c)*(x-c) repeats the subtraction unnecessarily, but perhaps assignment expressions could rescue us? For example, replace:

    sum((x-c)**2 for x in data)

with:

    sum((y:=(x-c))*y for x in data)
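The cost of repeating the subtraction, and how the assignment expression avoids it, can be seen by counting subtractions. The `CountingFloat` class and both helper functions are hypothetical, written only for this illustration; assignment expressions need Python 3.8 or later.

```python
class CountingFloat(float):
    """Hypothetical float subclass that counts subtractions (illustration only)."""

    subtractions = 0

    def __sub__(self, other):
        CountingFloat.subtractions += 1
        return float(self) - float(other)


def sum_sq_repeated(data, c):
    # (x-c)*(x-c): the deviation is computed twice per item.
    return sum((x - c) * (x - c) for x in data)


def sum_sq_walrus(data, c):
    # The assignment expression binds the deviation once, then reuses it.
    return sum((y := x - c) * y for x in data)


data = [CountingFloat(v) for v in (1.0, 2.0, 3.0, 4.0)]
c = 2.5

CountingFloat.subtractions = 0
repeated = sum_sq_repeated(data, c)
repeated_subs = CountingFloat.subtractions

CountingFloat.subtractions = 0
walrus = sum_sq_walrus(data, c)
walrus_subs = CountingFloat.subtractions

print(repeated, repeated_subs)  # 5.0 8
print(walrus, walrus_subs)      # 5.0 4
```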

[Steven]

> I am inclined to have the stdev of float32 return a float32 if possible.

Would that also imply intermediate calculations being performed only with float32, or would intermediate calculations be performed with a more precise type? float32 has small enough precision to run into real accuracy problems with modestly large datasets. For example:

    >>> import numpy as np
    >>> sum(np.ones(10**8, dtype=np.float32), start=np.float32(0))
    16777216.0  # should be 10**8

or, less dramatically:

    >>> sum(np.full(10**6, 1.73, dtype=np.float32), start=np.float32(0)) / 10**6
    1.74242125  # only two accurate significant figures
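The stall at 16777216.0 (2**24) can be reproduced without NumPy. This sketch assumes that round-tripping a value through struct's "f" format rounds it to the nearest IEEE 754 binary32 value, as CPython does on IEEE platforms; `to_float32` is a hypothetical helper, and the loop only simulates a float32 accumulator starting just below 2**24 rather than summing 10**8 ones.

```python
import struct


def to_float32(v: float) -> float:
    """Hypothetical helper: round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


acc32 = to_float32(2.0**24 - 2)  # float32 accumulator, just below 2**24
acc64 = 2.0**24 - 2              # float64 accumulator for comparison
history = []
for _ in range(5):
    acc32 = to_float32(acc32 + 1.0)  # each addition is rounded to float32
    acc64 += 1.0
    history.append(acc32)

print(history)  # [16777215.0, 16777216.0, 16777216.0, 16777216.0, 16777216.0]
print(acc64)    # 16777219.0
```

Once the float32 accumulator reaches 2**24, adding 1.0 lands exactly halfway between two float32 values and rounds back down, so every later addition is lost.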
msg359381 - Author: Reed (reed) Date: 2020-01-05 20:37
Thank you all for the comments! Either using (x-c)*(x-c) or removing the assertion and changing the final line to `return (U, total)` seems reasonable. I slightly prefer the former, due to Mark's comments about x*x being faster and simpler than x**2. But I am not an expert on this.

> I am inclined to have the stdev of float32 return a float32 if possible. What do you think?

Agreed.

> OTOH, (x-c)*(x-c) repeats the subtraction unnecessarily, but perhaps assignment expressions could rescue us?

Yeah, we should avoid repeating the subtraction. Another method of doing so is to define a square function. For example:

    def square(y):
        return y*y
    sum(square(x-c) for x in data)

> Would that also imply intermediate calculations being performed only with float32, or would intermediate calculations be performed with a more precise type?

Currently, statistics.py computes sums exactly, in arbitrary precision (https://github.com/python/cpython/blob/422ed16fb846eec0b5b2a4eb3a978c9862615665/Lib/statistics.py#L123), for any type. The multiplications (and exponentiations, if we go that route) would still be done in float32.
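The idea behind exact summation can be sketched with fractions.Fraction: convert each float to its exact rational value, add exactly, and round once at the end. This is a simplified illustration, not the actual `_sum` helper in statistics.py (which also tracks the input type and the count); `exact_sum` and `naive_sum` are hypothetical names.

```python
from fractions import Fraction


def exact_sum(values):
    """Hypothetical helper: sum floats exactly, then round once to float."""
    total = Fraction(0)
    for v in values:
        total += Fraction(v)  # Fraction(float) is the float's exact value
    return float(total)       # the only rounding step


def naive_sum(values):
    total = 0.0
    for v in values:
        total += v            # each addition rounds to the nearest float
    return total


data = [0.1] * 10
print(naive_sum(data))  # 0.9999999999999999
print(exact_sum(data))  # 1.0
```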
msg399964 - Author: Irit Katriel (iritkatriel) (Python committer) Date: 2021-08-20 11:51
I've reproduced this on 3.9 and 3.10. This part of the code in main is still the same, so the issue is probably there even though we don't have numpy with which to test.
msg400006 - Author: Raymond Hettinger (rhettinger) (Python committer) Date: 2021-08-21 00:41
Removing the assertion and implementing Steven's idea seems like the best way to go:

    sum((y:=(x-c)) * y for x in data)
msg400303 - Author: Raymond Hettinger (rhettinger) (Python committer) Date: 2021-08-26 01:14
The rounding correction in _ss() looks mathematically incorrect to me:

   ∑ (xᵢ - x̅ + εᵢ)² = ∑ (xᵢ - x̅)² - (∑ εᵢ)² ÷ n

If we drop this logic (which seems completely bogus), all the tests still pass and the code becomes cleaner:

    def _ss(data, c=None):
        if c is None:
            c = mean(data)
        T, total, count = _sum((y := x - c) * y for x in data)
        return (T, total)


-- Algebraic form of the current code ----------------------

from sympy import symbols, simplify

x1, x2, x3, e1, e2, e3 = symbols('x1 x2 x3 e1 e2 e3')
n = 3

# high accuracy mean
c = (x1 + x2 + x3) / n

# sum of squared deviations with subtraction errors
total = (x1 - c + e1)**2 + (x2 - c + e2)**2 + (x3 - c + e3)**2

# sum of subtraction errors = e1 + e2 + e3
total2 = (x1 - c + e1) + (x2 - c + e2) + (x3 - c + e3)

# corrected sum of squared deviations
total -= total2 ** 2 / n

# exact sum of squared deviations
desired = (x1 - c)**2 + (x2 - c)**2 + (x3 - c)**2

# expected versus actual
print(simplify(desired - total))

This gives:

    (e1 + e2 + e3)**2/3
    + (-2*x1 + x2 + x3)**2/9
    + (x1 - 2*x2 + x3)**2/9
    + (x1 + x2 - 2*x3)**2/9
    - (3*e1 + 2*x1 - x2 - x3)**2/9
    - (3*e2 - x1 + 2*x2 - x3)**2/9
    - (3*e3 - x1 - x2 + 2*x3)**2/9

-- Substituting in concrete values ----------------------

x1, x2, x3, e1, e2, e3 = 11, 17, 5, 0.3, 0.1, -0.2

This gives:

    75.74000000000001  uncorrected total
    75.72666666666667  "corrected" total
    72.0               desired result
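The concrete figures above can be reproduced in plain Python, without SymPy. This sketch simply evaluates the three quantities for the same values; the variable names mirror the SymPy script, and the printed digits may differ in the last place depending on evaluation order.

```python
x = [11, 17, 5]
e = [0.3, 0.1, -0.2]  # subtraction errors
n = len(x)
c = sum(x) / n        # high accuracy mean: 11.0

# sum of squared deviations with subtraction errors
total = sum((xi - c + ei) ** 2 for xi, ei in zip(x, e))
# sum of the perturbed deviations (mathematically e1 + e2 + e3)
total2 = sum(xi - c + ei for xi, ei in zip(x, e))
# "corrected" total, as in the current code
corrected = total - total2 ** 2 / n
# exact sum of squared deviations
desired = sum((xi - c) ** 2 for xi in x)

print(total, corrected, desired)
```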
msg400321 - Author: Mark Dickinson (mark.dickinson) (Python committer) Date: 2021-08-26 08:21
> The rounding correction in _ss() looks mathematically incorrect to me [...]

I don't think it was intended as a rounding correction - I think it's just computing the variance (prior to the division by n or n-1) of the `(x - c)` terms using the standard "expectation of x^2 - (expectation of x)^2" formula:

  sum((x - c)**2 for x in data) - (sum(x - c for x in data)**2) / n

So I guess it *can* be thought of as a rounding correction, but what it's correcting for is an inaccurate value of "c"; it's not correcting for inaccuracies in the subtraction results. That is, if you were to add an artificial error into c at some point before computing "total" and "total2", that correction term should take you back to something approaching the true sum of squares of deviations.

So mathematically, I think it's correct, but not useful, because mathematically "total2" will be zero. Numerically, it's probably not helpful.
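A small worked example shows the term doing what is described here: with a deliberately wrong centre c, subtracting sum(x - c)**2 / n recovers the sum of squared deviations about the true mean. The data and the error of 0.5 in c are arbitrary choices, picked so that every intermediate value is exactly representable as a float.

```python
data = [1.0, 2.0, 3.0, 4.0, 5.0]
n = len(data)
m = sum(data) / n  # true mean: 3.0
c = m + 0.5        # deliberately inaccurate centre (error e = 0.5)

raw = sum((x - c) ** 2 for x in data)      # 11.25, inflated by n*e**2
total2 = sum(x - c for x in data)          # -2.5, i.e. -n*e
corrected = raw - total2 ** 2 / n          # 11.25 - 1.25
true_ss = sum((x - m) ** 2 for x in data)  # 10.0

print(raw, corrected, true_ss)  # 11.25 10.0 10.0
```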
msg400323 - Author: Mark Dickinson (mark.dickinson) (Python committer) Date: 2021-08-26 08:38
> what it's correcting for is an inaccurate value of "c" [...]

In more detail:

Suppose "m" is the true mean of the x in data, but all we have is an approximate mean "c" to work with. Write "e" for the error in that approximation, so that c = m + e. Then (using Python notation, but treating the expressions as exact mathematical expressions computed in the reals):

   sum((x-c)**2 for x in data)

== sum((x-m-e)**2 for x in data)

== sum((x - m)**2 for x in data) - 2 * sum((x - m)*e for x in data)
                                 + sum(e**2 for x in data)

== sum((x - m)**2 for x in data) - 2 * e * sum((x - m) for x in data)
                                 + sum(e**2 for x in data)

== sum((x - m)**2 for x in data) + sum(e**2 for x in data)
       (because sum((x - m) for x in data) is 0)

== sum((x - m)**2 for x in data) + n*e**2

So the error in our result arising from the error in computing m is that n*e**2 term. And that's the term that's being subtracted here, because

   sum(x - c for x in data) ** 2 / n
== sum(x - m - e for x in data) ** 2 / n
== (sum(x - m for x in data) - sum(e for x in data))**2 / n
== (0 - n * e)**2 / n
== n * e**2
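Both identities in the derivation can be checked exactly with fractions.Fraction, which avoids floating-point rounding entirely. The random data and the size of the error e are arbitrary choices made only for this check.

```python
import random
from fractions import Fraction

rng = random.Random(2021)
data = [Fraction(rng.randint(-1000, 1000), rng.randint(1, 50)) for _ in range(20)]
n = len(data)

m = sum(data) / n                       # exact mean
e = Fraction(rng.randint(1, 99), 1000)  # arbitrary error in the approximate mean
c = m + e

# sum((x-c)**2) == sum((x-m)**2) + n*e**2
lhs = sum((x - c) ** 2 for x in data)
rhs = sum((x - m) ** 2 for x in data) + n * e ** 2
# the subtracted correction term equals n*e**2
correction = sum(x - c for x in data) ** 2 / n

print(lhs == rhs, correction == n * e ** 2)  # True True
```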
msg400359 - Author: Raymond Hettinger (rhettinger) (Python committer) Date: 2021-08-26 16:35
> what it's correcting for is an inaccurate value of "c" [...]

I'll leave the logic as-is and just add a note about what is being corrected.

> Numerically, it's probably not helpful.

To make a difference, the mean would have to have huge magnitude relative to the variance; otherwise, squaring the error would drown it out to zero.

The mean() should already be accurate to within 1/2 ulp. The summation and division are exact. There is only a single rounding when the result converts from Fraction to a float or decimal.
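The single rounding in mean() is easy to observe on data where naive floating-point accumulation cancels catastrophically. The data set is an arbitrary example, and the naive loop is written out explicitly so that every addition is rounded to a float.

```python
import statistics

data = [1e16, 1.0, -1e16]

# Naive float accumulation: 1e16 + 1.0 rounds back to 1e16, so the 1.0 is lost.
acc = 0.0
for v in data:
    acc += v
naive_mean = acc / len(data)

exact_mean = statistics.mean(data)  # exact sum 1, then one rounding of 1/3

print(naive_mean, exact_mean)  # 0.0 0.3333333333333333
```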
msg400681 - Author: Raymond Hettinger (rhettinger) (Python committer) Date: 2021-08-31 01:57
New changeset 793f55bde9b0299100c12ddb0e6949c6eb4d85e5 by Raymond Hettinger in branch 'main':
bpo-39218: Improve accuracy of variance calculation (GH-27960)
https://github.com/python/cpython/commit/793f55bde9b0299100c12ddb0e6949c6eb4d85e5
History
Date                 User            Action  Args
2022-04-11 14:59:24  admin           set     github: 83399
2021-08-31 01:58:53  rhettinger      set     status: open -> closed
                                             resolution: fixed
                                             stage: patch review -> resolved
2021-08-31 01:57:48  rhettinger      set     messages: + msg400681
2021-08-26 16:35:52  rhettinger      set     messages: + msg400359
2021-08-26 08:38:33  mark.dickinson  set     messages: + msg400323
2021-08-26 08:21:11  mark.dickinson  set     messages: + msg400321
2021-08-26 01:20:31  rhettinger      set     keywords: + patch
                                             stage: patch review
                                             pull_requests: + pull_request26406
2021-08-26 01:14:13  rhettinger      set     messages: + msg400303
2021-08-25 13:20:19  taleinat        set     nosy: - taleinat
2021-08-21 00:41:40  rhettinger      set     messages: + msg400006
2021-08-20 11:51:03  iritkatriel     set     nosy: + iritkatriel
                                             messages: + msg399964
                                             versions: + Python 3.9, Python 3.10, Python 3.11, - Python 3.8
2020-01-05 20:37:33  reed            set     messages: + msg359381
2020-01-05 10:56:51  mark.dickinson  set     nosy: + mark.dickinson
                                             messages: + msg359329
2020-01-05 08:03:42  steven.daprano  set     messages: + msg359325
2020-01-05 07:56:10  xtreak          set     nosy: + rhettinger, xtreak, steven.daprano, taleinat
                                             messages: + msg359324
2020-01-05 05:34:52  reed            create