diff -r ac11bf14618d Lib/statistics.py --- a/Lib/statistics.py Sun Sep 27 02:14:23 2015 -0700 +++ b/Lib/statistics.py Sun Sep 27 17:01:18 2015 +0300 @@ -114,13 +114,17 @@ # === Private utilities === -def _sum(data, start=0): - """_sum(data [, start]) -> value +def _sum(data, start=0, fraction=False): + """_sum(data [, start][, fraction]) -> value Return a high-precision sum of the given numeric data. If optional argument ``start`` is given, it is added to the total. If ``data`` is - empty, ``start`` (defaulting to 0) is returned. - + empty, ``start`` (defaulting to 0) is returned. If ``fraction`` is + True, will return the sum as a fractions.Fraction. This allows greater + precision. + In case any of the data causes OverflowError (inf or -inf for example) + fraction=True will be ignored and return the sum in its original form + as fractions do not support inf. Examples -------- @@ -136,61 +140,118 @@ Fractions and Decimals are also supported: >>> from fractions import Fraction as F - >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)]) - Fraction(63, 20) + >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6), 1]) + Fraction(83, 20) >>> from decimal import Decimal as D >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")] >>> _sum(data) Decimal('0.6963') - Mixed types are currently treated as an error, except that int is - allowed. + Mixed types are treated as an error whether or not fraction is True. + ints, however, do not count toward the type limit. """ # We fail as soon as we reach a value that is not an int or the type of # the first value which is not an int. E.g. _sum([int, int, float, int]) # is okay, but sum([int, int, float, Fraction]) is not. - allowed_types = {int, type(start)} n, d = _exact_ratio(start) partials = {d: n} # map {denominator: sum of numerators} + data_types = {int, type(start)} # Helps avoid generator exhaustion + # Micro-optimizations. exact_ratio = _exact_ratio partials_get = partials.get + # Add numerators for each denominator. 
for x in data: - _check_type(type(x), allowed_types) + data_types.add(type(x)) + + # Micro-optimization, avoids creating set every cycle + # The check in here also avoids cycling over all data + # in case of an error at the beginning. + if len(data_types) > 2: + _check_mixed_types(data_types) + n, d = exact_ratio(x) partials[d] = partials_get(d, 0) + n - # Find the expected result type. If allowed_types has only one item, it - # will be int; if it has two, use the one which isn't int. - assert len(allowed_types) in (1, 2) - if len(allowed_types) == 1: - assert allowed_types.pop() is int - T = int - else: - T = (allowed_types - {int}).pop() + if None in partials: - assert issubclass(T, (float, Decimal)) assert not math.isfinite(partials[None]) - return T(partials[None]) + return partials[None] + total = Fraction() for d, n in sorted(partials.items()): total += Fraction(n, d) - if issubclass(T, int): - assert total.denominator == 1 - return T(total.numerator) - if issubclass(T, Decimal): - return T(total.numerator)/total.denominator - return T(total) + if fraction: + return total + + return _convert_common_type(total, types=data_types) -def _check_type(T, allowed): - if T not in allowed: - if len(allowed) == 1: - allowed.add(T) - else: - types = ', '.join([t.__name__ for t in allowed] + [T.__name__]) - raise TypeError("unsupported mixed types: %s" % types) + +def _check_mixed_types(types, exempt=frozenset({int})): + """_check_mixed_types(types [, exempt]) -> None + + Verifies there is only 1 type in ``types`` other than + the types in the ``exempt`` set. + + Raises TypeError in case there are mixed types + not in the exempt set. + """ + if len(set(types).difference(exempt)) > 1: + types = ', '.join(sorted(t.__name__ for t in types)) + raise TypeError("unsupported mixed types: %s" % types) + + +def _convert_common_type(value, data=None, types=None): + """_convert_common_type(value [, data][, types]) -> value + + Converts ``value`` to a type common to ``data``. + Mixed types are not allowed. 
In case ``data`` includes + an int and another numeric class, this function will return + the higher precision class. + If data contains only ints, or is empty, value will be + converted into an int if there is no data loss or float if not. + If ``types`` is specified, ``data`` is ignored. ``types`` may + be any iterable of types. + + Examples + -------- + + >>> from fractions import Fraction + >>> _convert_common_type(Fraction(5,3), [1, 2, 3]) + 1.6666666666666667 + + >>> from fractions import Fraction + >>> from decimal import Decimal + >>> _convert_common_type(Fraction(3,2), [1, Decimal("2.2"), 3]) + Decimal('1.5') + + >>> from fractions import Fraction + >>> _convert_common_type(Fraction(9,8), [1, 2, 3.2]) + 1.125 + """ + if types is not None: + data_types = set(types) | {int} + else: + data_types = {type(x) for x in data} + data_types.add(int) + + # Check if there are 2 or more types besides int + if len(data_types) > 2: + _check_mixed_types(data_types, exempt={int}) + + if len(data_types) == 1: + assert data_types.pop() is int + if value % 1 == 0: # No data loss + return int(value) + return float(value) + + T = (data_types - {int}).pop() + if issubclass(T, Decimal) and isinstance(value, Fraction): + return T(value.numerator)/value.denominator + + return T(value) def _exact_ratio(x): @@ -275,6 +336,12 @@ >>> mean([1, 2, 3, 4, 4]) 2.8 + >>> mean([8.99e+307, 8.989e+307]) + 8.9895e+307 + + >>> mean([5e-324, 5e-324]) + 5e-324 + >>> from fractions import Fraction as F >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)]) Fraction(13, 21) @@ -290,7 +357,7 @@ n = len(data) if n < 1: raise StatisticsError('mean requires at least one data point') - return _sum(data)/n + return _convert_common_type(_sum(data, fraction=True)/n, data=data) # FIXME: investigate ways to calculate medians without sorting? Quickselect? 
@@ -311,7 +378,7 @@ n = len(data) if n == 0: raise StatisticsError("no median for empty data") - if n%2 == 1: + if n % 2 == 1: return data[n//2] else: i = n//2 @@ -334,7 +401,7 @@ n = len(data) if n == 0: raise StatisticsError("no median for empty data") - if n%2 == 1: + if n % 2 == 1: return data[n//2] else: return data[n//2 - 1] diff -r ac11bf14618d Lib/test/test_statistics.py --- a/Lib/test/test_statistics.py Sun Sep 27 02:14:23 2015 -0700 +++ b/Lib/test/test_statistics.py Sun Sep 27 17:01:18 2015 +0300 @@ -718,25 +718,19 @@ self.assertEqual(t, (147000, 1)) -class CheckTypeTest(unittest.TestCase): - # Test _check_type private function. +class CheckMixedTypesTest(unittest.TestCase): + # Test _check_mixed_types private function. def test_allowed(self): # Test that a type which should be allowed is allowed. allowed = set([int, float]) - statistics._check_type(int, allowed) - statistics._check_type(float, allowed) + statistics._check_mixed_types(allowed) + statistics._check_mixed_types(set([int, Decimal])) def test_not_allowed(self): # Test that a type which should not be allowed raises. - allowed = set([int, float]) - self.assertRaises(TypeError, statistics._check_type, Decimal, allowed) - - def test_add_to_allowed(self): - # Test that a second type will be added to the allowed set. - allowed = set([int]) - statistics._check_type(float, allowed) - self.assertEqual(allowed, set([int, float])) + allowed = set([int, float, Decimal]) + self.assertRaises(TypeError, statistics._check_mixed_types, allowed) # === Tests for public functions === @@ -880,7 +874,7 @@ def setUp(self): self.func = statistics._sum - def test_empty_data(self): + def test_empty_data(self): # Override test for empty data. for data in ([], (), iter([])): self.assertEqual(self.func(data), 0)