import fractions
import itertools
import math
import unittest

F = fractions.Fraction


def _is_inf_times_zero(a, b):
    """Return True if the product a * b has the invalid form inf * 0.

    Both operand orders (inf * 0 and 0 * inf) and both signs of infinity and
    zero are covered.
    """
    return (math.isinf(a) and b == 0) or (math.isinf(b) and a == 0)


def fma(a, b, c, corner_case_raises=False):
    """
    IEEE 754-2008 compliant FMA implementation: a * b + c, rounded once.

    The IEEE 754-2008 specification leaves one FMA corner case to the
    implementation: namely whether the cases

        zero * infinity + nan    and    infinity * zero + nan

    return nan or raise an invalid operation FPE. By default, this
    implementation follows Intel's x64 FMA3 implementation, which returns
    NaN in these cases. If corner_case_raises is True, then this corner case
    raises ValueError instead (which matches the behaviour of the decimal
    module's fma operation).

    Args:
        a: First factor.
        b: Second factor.
        c: Addend.
        corner_case_raises: If True, inf * 0 + nan raises ValueError instead
            of returning NaN.

    Returns:
        The NaN input itself if any input is a NaN (a is checked first, then
        b, then c); otherwise the exact value of a * b + c converted to a
        float in a single step. An exact zero result is -0.0 only when both
        the product and the addend are negative, and 0.0 otherwise.

    Raises:
        ValueError: For an invalid operation: inf * 0 with a non-NaN addend
            (or with a NaN addend when corner_case_raises is True), or an
            infinite product added to an infinity of the opposite sign.
        OverflowError: If the exact finite result is outside the range of a
            float.
    """
    # No need to worry about signaling NaNs, since Python doesn't support them.

    # Reduce to case where all values are non-NaN.
    if math.isnan(a):
        return a
    if math.isnan(b):
        return b

    invalid_product = _is_inf_times_zero(a, b)
    if math.isnan(c):
        # The implementation-defined corner case. IBM decimal arithmetic spec
        # style: inf * 0 + nan raises. Intel style: it returns the NaN.
        if corner_case_raises and invalid_product:
            raise ValueError("Invalid operation in math.fma.")
        return c
    if invalid_product:
        # inf * 0 with a non-NaN addend is invalid in both styles.
        raise ValueError("Invalid operation in math.fma.")

    # For correct handling of infs and zeros, we need the signs of a*b and c.
    sign_ab = math.copysign(1.0, a) * math.copysign(1.0, b)
    sign_c = math.copysign(1.0, c)

    # Reduce further to case where all values are finite.
    if math.isinf(a) or math.isinf(b):
        # An infinite product plus an infinity of the opposite sign is invalid.
        if math.isinf(c) and sign_ab != sign_c:
            raise ValueError("Invalid operation in math.fma.")
        return -math.inf if sign_ab < 0 else math.inf
    if math.isinf(c):
        return c

    # Now a, b and c are finite. Compute the result exactly, as a fraction.
    result_as_fraction = F(a) * F(b) + F(c)
    if result_as_fraction == 0:
        # The fraction carries no sign for zero, so derive it from the inputs.
        return -0.0 if sign_ab == sign_c == -1.0 else 0.0

    # Will raise OverflowError if result is outside the range of a float.
    return float(result_as_fraction)


class TestFMA(unittest.TestCase):
    """Tests for fma: NaNs, infinities, signed zeros, overflow and rounding."""

    def test_fma_nan_results(self):
        # Combinations where at least one input is a NaN.
        # Each combination should return NaN, without raising.
        # Excludes the inf * 0 + nan case.
        selected = [
            -math.inf,
            -1e300,
            -2.3,
            -1e-300,
            -0.0,
            0.0,
            1e-300,
            2.3,
            1e300,
            math.inf,
            math.nan,
        ]
        nan_combinations = [
            (a, b, c)
            for a, b, c in itertools.product(selected, repeat=3)
            if math.isnan(a) or math.isnan(b) or math.isnan(c)
        ]
        for a, b, c in nan_combinations:
            # case inf * 0 + nan and 0 * inf + nan is implementation dependent,
            # and is tested separately below.
            if _is_inf_times_zero(a, b):
                continue
            # subTest makes a failure report the offending input triple.
            with self.subTest(a=a, b=b, c=c):
                self.assertIsNaN(fma(a, b, c))

    def test_fma_implementation_defined_case_intel_style(self):
        self.assertIsNaN(fma(0, math.inf, math.nan, corner_case_raises=False))
        self.assertIsNaN(fma(math.inf, 0, math.nan, corner_case_raises=False))

    def test_fma_implementation_defined_case_ibm_style(self):
        with self.assertRaises(ValueError):
            fma(0, math.inf, math.nan, corner_case_raises=True)
        with self.assertRaises(ValueError):
            fma(math.inf, 0, math.nan, corner_case_raises=True)

    def test_fma_infinities(self):
        positives = [1e-300, 2.3, 1e300, math.inf]
        finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300]
        non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf]

        # ValueError due to inf * 0 computation.
        for c in non_nans:
            with self.assertRaises(ValueError):
                fma(math.inf, 0.0, c)
            with self.assertRaises(ValueError):
                fma(-math.inf, 0.0, c)
            with self.assertRaises(ValueError):
                fma(math.inf, -0.0, c)
            with self.assertRaises(ValueError):
                fma(-math.inf, -0.0, c)

        # ValueError when a*b and c both infinite of opposite signs.
        for b in positives:
            with self.assertRaises(ValueError):
                fma(math.inf, b, -math.inf)
            with self.assertRaises(ValueError):
                fma(math.inf, -b, math.inf)
            with self.assertRaises(ValueError):
                fma(-math.inf, -b, -math.inf)
            with self.assertRaises(ValueError):
                fma(-math.inf, b, math.inf)

        # Infinite result when a*b and c both infinite of the same sign.
        for b in positives:
            self.assertEqual(fma(math.inf, b, math.inf), math.inf)
            self.assertEqual(fma(math.inf, -b, -math.inf), -math.inf)
            self.assertEqual(fma(-math.inf, -b, math.inf), math.inf)
            self.assertEqual(fma(-math.inf, b, -math.inf), -math.inf)

        # Infinite result when a*b finite, c infinite.
        for a, b in itertools.product(finites, finites):
            self.assertEqual(fma(a, b, math.inf), math.inf)
            self.assertEqual(fma(a, b, -math.inf), -math.inf)

        # Infinite result when a*b infinite, c finite.
        for b, c in itertools.product(positives, finites):
            self.assertEqual(fma(math.inf, b, c), math.inf)
            self.assertEqual(fma(-math.inf, b, c), -math.inf)
            self.assertEqual(fma(-math.inf, -b, c), math.inf)
            self.assertEqual(fma(math.inf, -b, c), -math.inf)

            self.assertEqual(fma(b, math.inf, c), math.inf)
            self.assertEqual(fma(b, -math.inf, c), -math.inf)
            self.assertEqual(fma(-b, -math.inf, c), math.inf)
            self.assertEqual(fma(-b, math.inf, c), -math.inf)

    def test_fma_zero_result(self):
        nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]

        # Zero results from exact zero inputs.
        for b in nonnegative_finites:
            self.assertIsPositiveZero(fma(0.0, b, 0.0))
            self.assertIsPositiveZero(fma(0.0, b, -0.0))
            self.assertIsNegativeZero(fma(0.0, -b, -0.0))
            self.assertIsPositiveZero(fma(0.0, -b, 0.0))
            self.assertIsPositiveZero(fma(-0.0, -b, 0.0))
            self.assertIsPositiveZero(fma(-0.0, -b, -0.0))
            self.assertIsNegativeZero(fma(-0.0, b, -0.0))
            self.assertIsPositiveZero(fma(-0.0, b, 0.0))

            self.assertIsPositiveZero(fma(b, 0.0, 0.0))
            self.assertIsPositiveZero(fma(b, 0.0, -0.0))
            self.assertIsNegativeZero(fma(-b, 0.0, -0.0))
            self.assertIsPositiveZero(fma(-b, 0.0, 0.0))
            self.assertIsPositiveZero(fma(-b, -0.0, 0.0))
            self.assertIsPositiveZero(fma(-b, -0.0, -0.0))
            self.assertIsNegativeZero(fma(b, -0.0, -0.0))
            self.assertIsPositiveZero(fma(b, -0.0, 0.0))

        # Exact zero result from nonzero inputs.
        self.assertIsPositiveZero(fma(2.0, 2.0, -4.0))
        self.assertIsPositiveZero(fma(2.0, -2.0, 4.0))
        self.assertIsPositiveZero(fma(-2.0, -2.0, -4.0))
        self.assertIsPositiveZero(fma(-2.0, 2.0, 4.0))

        # Underflow to zero.
        tiny = 1e-300
        self.assertIsPositiveZero(fma(tiny, tiny, 0.0))
        self.assertIsNegativeZero(fma(tiny, -tiny, 0.0))
        self.assertIsPositiveZero(fma(-tiny, -tiny, 0.0))
        self.assertIsNegativeZero(fma(-tiny, tiny, 0.0))
        self.assertIsPositiveZero(fma(tiny, tiny, -0.0))
        self.assertIsNegativeZero(fma(tiny, -tiny, -0.0))
        self.assertIsPositiveZero(fma(-tiny, -tiny, -0.0))
        self.assertIsNegativeZero(fma(-tiny, tiny, -0.0))

        # Corner case where rounding the multiplication would
        # give the wrong result.
        x = float.fromhex('0x1p-500')
        y = float.fromhex('0x1p-550')
        z = float.fromhex('0x1p-1000')
        self.assertIsNegativeZero(fma(x-y, x+y, -z))
        self.assertIsPositiveZero(fma(y-x, x+y, z))
        self.assertIsNegativeZero(fma(y-x, -(x+y), -z))
        self.assertIsPositiveZero(fma(x-y, -(x+y), z))

    def test_fma_overflow(self):
        a = b = float.fromhex('0x1p512')
        c = float.fromhex('0x1p1023')
        # Overflow from multiplication.
        with self.assertRaises(OverflowError):
            fma(a, b, 0.0)
        # Overflow from the addition.
        with self.assertRaises(OverflowError):
            fma(a, b/2.0, c)
        # No overflow, even though a*b overflows a float.
        self.assertEqual(fma(a, b, -c), c)

        # Extreme case: a * b is exactly at the overflow boundary, so the
        # tiniest offset makes a difference between overflow and a finite
        # result.
        a = float.fromhex('0x1.ffffffc000000p+511')
        b = float.fromhex('0x1.0000002000000p+512')
        c = float.fromhex('0x0.0000000000001p-1022')
        with self.assertRaises(OverflowError):
            fma(a, b, 0.0)
        with self.assertRaises(OverflowError):
            fma(a, b, c)
        self.assertEqual(fma(a, b, -c),
                         float.fromhex('0x1.fffffffffffffp+1023'))

        # Another extreme case: here a*b is about as large as possible subject
        # to fma(a, b, c) being finite.
        a = float.fromhex('0x1.ae565943785f9p+512')
        b = float.fromhex('0x1.3094665de9db8p+512')
        c = float.fromhex('0x1.fffffffffffffp+1023')
        self.assertEqual(fma(a, b, -c), c)

    def test_fma_single_round(self):
        # (a - 1)(a + 1) + 1 is exactly a*a; rounding the product first
        # would lose the a*a term.
        a = float.fromhex('0x1p-50')
        self.assertEqual(fma(a - 1.0, a + 1.0, 1.0), a*a)

    def test_random(self):
        # A collection of randomly generated inputs for which the naive FMA
        # (with two rounds) gives a different result from a singly-rounded FMA.

        # tuples (a, b, c, expected)
        test_values = [
            ('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1',
             '0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'),
            ('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2',
             '0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'),
            ('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1',
             '0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'),
            ('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1',
             '0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'),
            ('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1',
             '0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'),
            ('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1',
             '0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'),
            ('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2',
             '0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'),
            ('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1',
             '0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'),
            ('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1',
             '0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'),
            ('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1',
             '0x1.122b59317ebdfp-1', '0x1.0bf644b340cc5p+0'),
            ('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1',
             '0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'),
            ('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1',
             '0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'),
            ('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1',
             '0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'),
            ('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1',
             '0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'),
            ('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2',
             '0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'),
            ('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2',
             '0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'),
        ]
        for a_hex, b_hex, c_hex, expected_hex in test_values:
            a = float.fromhex(a_hex)
            b = float.fromhex(b_hex)
            c = float.fromhex(c_hex)
            expected = float.fromhex(expected_hex)
            # subTest makes a failure report the offending table row.
            with self.subTest(a=a_hex, b=b_hex, c=c_hex):
                self.assertEqual(fma(a, b, c), expected)

    # Custom assertions.

    def assertIsNaN(self, value):
        """Fail unless value is a NaN."""
        self.assertTrue(math.isnan(value),
                        msg="Expected a NaN, got {!r}".format(value))

    def assertIsPositiveZero(self, value):
        """Fail unless value is a zero with a positive sign bit."""
        self.assertTrue(
            value == 0 and math.copysign(1, value) > 0,
            msg="Expected a positive zero, got {!r}".format(value)
        )

    def assertIsNegativeZero(self, value):
        """Fail unless value is a zero with a negative sign bit."""
        self.assertTrue(
            value == 0 and math.copysign(1, value) < 0,
            msg="Expected a negative zero, got {!r}".format(value)
        )