diff --git Lib/test/test_math.py Lib/test/test_math.py
index 525ee09..8fa45c8 100644
--- Lib/test/test_math.py
+++ Lib/test/test_math.py
@@ -365,18 +365,18 @@ class MathTests(unittest.TestCase):
         self.ftest('fabs(1)', math.fabs(1), 1)
 
     def testFactorial(self):
-        def fact(n):
-            result = 1
-            for i in range(1, int(n)+1):
-                result *= i
-            return result
-        values = list(range(10)) + [50, 100, 500]
-        random.shuffle(values)
-        for x in range(10):
-            for cast in (int, float):
-                self.assertEqual(math.factorial(cast(x)), fact(x), (x, fact(x), math.factorial(x)))
+        self.assertEqual(math.factorial(0), 1)
+        self.assertEqual(math.factorial(0.0), 1)
+        total = 1
+        for i in range(1, 1000):
+            total *= i
+            self.assertEqual(math.factorial(i), total)
+            self.assertEqual(math.factorial(float(i)), total)
         self.assertRaises(ValueError, math.factorial, -1)
+        self.assertRaises(ValueError, math.factorial, -1.0)
         self.assertRaises(ValueError, math.factorial, math.pi)
+        self.assertRaises(OverflowError, math.factorial, sys.maxsize+1)
+        self.assertRaises(OverflowError, math.factorial, 10e100)
 
     def testFloor(self):
         self.assertRaises(TypeError, math.floor)
diff --git Modules/mathmodule.c Modules/mathmodule.c
index 76d7906..5a05f70 100644
--- Modules/mathmodule.c
+++ Modules/mathmodule.c
@@ -1129,11 +1129,155 @@ PyDoc_STRVAR(math_fsum_doc,
 Return an accurate floating point sum of values in the iterable.\n\
 Assumes IEEE-754 floating point arithmetic.");
 
+/* Find the index of the highest set bit. Equivalent to floor(lg(x))+1.
+ * Also equivalent to: bitwidth_of_type - count_leading_zero_bits(x)
+ */
+
+/* XXX: This routine does more or less the same thing as
+ * bits_in_digit() in Objects/longobject.c. Someday it would be nice to
+ * consolidate them. On BSD, there's a library function called fls()
+ * that we could use, and GCC provides __builtin_clz().
+ */
+
+static unsigned long
+find_last_set_bit(unsigned long n)
+{
+    unsigned long len = 0;
+    while (n != 0) {
+        ++len;
+        n >>= 1;
+    }
+    return len;
+}
+
+/* Return the number of set bits in n (its population count). */
+static unsigned long
+count_set_bits(unsigned long n)
+{
+    unsigned long count = 0;
+    while (n != 0) {
+        ++count;
+        n &= n - 1; /* clear least significant set bit */
+    }
+    return count;
+}
+
+/* Divide-and-conquer factorial algorithm
+ *
+ * Based on the formula and pseudo-code provided at:
+ * http://www.luschny.de/math/factorial/binarysplitfact.html
+ *
+ * Faster algorithms exist, but they're more complicated and depend on
+ * a fast prime factorization algorithm.
+ */
+
+/* Compute product(range(n, m, 2)) using divide and conquer. Assumes
+ * n and m are odd and m > n. max_bits must be >= find_last_set_bit(m-2). */
+static PyObject *
+factorial_partial_product(unsigned long n, unsigned long m,
+                          unsigned long max_bits)
+{
+    unsigned long k, num_operands, output_bits;
+    PyObject *left = NULL, *right = NULL, *result = NULL;
+
+    /* If the return value will fit an unsigned long, then we can
+     * multiply in a tight, fast loop where each multiply is O(1).
+     * Compute an upper bound on the number of bits required to store
+     * the answer.
+     *
+     * Storing some integer z requires floor(lg(z))+1 bits, which is
+     * conveniently the value returned by find_last_set_bit(z). The
+     * product of x*y will require at most
+     * find_last_set_bit(x)+find_last_set_bit(y) bits to store, based
+     * on the idea that lg product = lg x + lg y.
+     *
+     * We know that m is the largest number to be multiplied. From
+     * there, we have:
+     * find_last_set_bit(answer) <= num_operands * find_last_set_bit(m)
+     */
+
+    num_operands = (m-n)/2;
+    output_bits = num_operands * max_bits;
+    /* "output_bits > num_operands" checks the unlikely case of an
+     * overflow in the multiplication above. */
+    if (output_bits > num_operands && output_bits <= sizeof(unsigned long)*8) {
+        unsigned long total = n;
+        for (n += 2; n < m; n += 2)
+            total *= n;
+        return PyLong_FromUnsignedLong(total);
+    }
+
+    /* n previously fit in a signed long, so the + cannot overflow. */
+    k = n + num_operands;
+    k = (k - 1) | 1; /* Round down to nearest odd number */
+    left = factorial_partial_product(n, k, find_last_set_bit(k-2));
+    if (left == NULL) goto done;
+    right = factorial_partial_product(k, m, max_bits);
+    if (right == NULL) goto done;
+    result = PyNumber_Multiply(left, right);
+done:
+    Py_XDECREF(left);
+    Py_XDECREF(right);
+    return result;
+}
+
+/* Return the odd part of n!, i.e. n! with every factor of two removed.
+ * math_factorial() restores the power of two afterwards by shifting the
+ * result left by n - count_set_bits(n) bits.
+ *
+ * Returns a new reference, or NULL if an allocation or multiply fails.
+ */
+static PyObject *
+factorial_loop(unsigned long n)
+{
+    long i;
+    unsigned long v, lower, upper;
+    PyObject *partial, *tmp, *p = NULL, *r = NULL;
+
+    p = PyLong_FromLong(1);
+    if (p == NULL)
+        return NULL;
+    r = p;
+    Py_INCREF(r);
+
+    upper = 1;
+    for (i = find_last_set_bit(n)-2; i >= 0; i--) {
+        v = n >> i;
+        lower = upper;
+        upper = (v + 1) | 1;
+        if (v <= 2) continue;
+        partial = factorial_partial_product(lower, upper,
+                                            find_last_set_bit(upper-2));
+        /* p *= partial */
+        if (partial == NULL) goto error;
+        tmp = PyNumber_Multiply(p, partial);
+        Py_DECREF(partial);
+        if (tmp == NULL) goto error;
+        Py_DECREF(p);
+        p = tmp;
+
+        /* r *= p; */
+        tmp = PyNumber_Multiply(r, p);
+        if (tmp == NULL) goto error;
+        Py_DECREF(r);
+        r = tmp;
+    }
+
+    goto done;
+
+error:
+    /* Release the partial result and return NULL, not a dangling pointer. */
+    Py_CLEAR(r);
+done:
+    Py_DECREF(p);
+    return r;
+}
+
 static PyObject *
 math_factorial(PyObject *self, PyObject *arg)
 {
-    long i, x;
-    PyObject *result, *iobj, *newresult;
+    long x;
+    PyObject *result = NULL, *r = NULL, *nminusnumbits = NULL;
 
     if (PyFloat_Check(arg)) {
         PyObject *lx;
@@ -1160,25 +1304,29 @@ math_factorial(PyObject *self, PyObject *arg)
         return NULL;
     }
 
-    result = (PyObject *)PyLong_FromLong(1);
-    if (result == NULL)
-        return NULL;
-    for (i=1 ; i<=x ; i++) {
-        iobj = (PyObject *)PyLong_FromLong(i);
-        if (iobj == NULL)
-            goto error;
-        newresult = PyNumber_Multiply(result, iobj);
-        Py_DECREF(iobj);
-        if (newresult == NULL)
-            goto error;
-        Py_DECREF(result);
-        result = newresult;
+    if (x <= 12) {
+        /* 0! through 12!; 12! is the largest factorial below 2**32. */
+        static const unsigned lookup[] = {
+            1, 1, 2, 6, 24, 120, 720, 5040, 40320,
+            362880, 3628800, 39916800, 479001600
+        };
+        result = PyLong_FromLong(lookup[x]);
+        return result;
     }
+
+    r = factorial_loop(x);
+    if (r == NULL)
+        return NULL;
+
+    nminusnumbits = PyLong_FromLong(x - count_set_bits(x));
+    if (nminusnumbits == NULL)
+        goto done;
+    result = PyNumber_Lshift(r, nminusnumbits);
+    Py_DECREF(nminusnumbits);
+
+done:
+    Py_DECREF(r);
     return result;
-
-error:
-    Py_DECREF(result);
-    return NULL;
 }
 
 PyDoc_STRVAR(math_factorial_doc,