Index: longobject.c =================================================================== --- longobject.c (revision 66680) +++ longobject.c (working copy) @@ -90,6 +90,31 @@ #define MAX(x, y) ((x) < (y) ? (y) : (x)) #define MIN(x, y) ((x) > (y) ? (y) : (x)) +/* long_digits is supposed to be a 64-bit int, twice a twodigits, + * which is twice a digit; if there is no 64-bit int type, + * long_digits is a twodigits. + * Grade school multiplication is performed by x_mul, which uses + * operands made of digit elements, or by x_mul2, with twodigits elements; + * below the size MUL2 it is more convenient to use the former. + * FIXME write these macros in a portable way. + */ +#ifdef HAVE_LONG_LONG +#define HAVE_INT64 +#endif + +#ifdef HAVE_INT64 +#define MUL2 20 +#define long_digits unsigned PY_LONG_LONG +#define PyLong_SHIFT2 (PyLong_SHIFT<<1) +#define PyLong_BASE2 ((twodigits)1 << PyLong_SHIFT2) +#define PyLong_MASK2 ((twodigits)(PyLong_BASE2 - 1)) +#else +#define long_digits twodigits +#define PyLong_SHIFT2 PyLong_SHIFT +#define PyLong_BASE2 PyLong_BASE +#define PyLong_MASK2 PyLong_MASK +#endif + /* Forward */ static PyLongObject *long_normalize(PyLongObject *); static PyLongObject *mul1(PyLongObject *, wdigit); @@ -2434,45 +2459,249 @@ return (PyObject *)z; } +/* Utilities to convert a Python long to a calloc()ed array of twodigits (two digits packed per element; the caller must free() it) and + * vice versa; used in x_mul2(). to_twodigits() returns NULL with MemoryError set if the allocation fails. + */ + +static twodigits* +to_twodigits(PyLongObject *a) +{ + int size_a = ABS(Py_SIZE(a)); + int size2_a = (size_a&1)? 
(size_a>>1) + 1: (size_a>>1); + twodigits *dig2 = (twodigits*) calloc(size2_a, sizeof(twodigits)); + if (dig2 == NULL) + return (twodigits *)PyErr_NoMemory(); + int i, j; + for(i=0, j=0; i + 1 < size_a; i += 2, j += 1){ /* complete pairs only, so ob_digit[size_a] is never read */ + dig2[j] = a->ob_digit[i] + ((twodigits)a->ob_digit[i+1] << PyLong_SHIFT); + } + if (size_a & 1) /* odd digit count: the last element holds the top digit alone */ + dig2[size2_a - 1] = a->ob_digit[size_a - 1]; + return dig2; +} + +static void +from_twodigits(PyLongObject*z, twodigits* p, int size2_p) +{ /* Unpack each of the size2_p elements of p into two consecutive digits of z; writes z->ob_digit[0 .. 2*size2_p - 1]. */ + int i, j; + for(i=0, j=0; i < size2_p; i++, j += 2) { + z->ob_digit[j] = (digit) (p[i] & PyLong_MASK); + z->ob_digit[j+1] = (digit) (p[i] >> PyLong_SHIFT); + } +} + /* Grade school multiplication, ignoring the signs. * Returns the absolute value of the product, or NULL if error. */ static PyLongObject * x_mul(PyLongObject *a, PyLongObject *b) { + PyLongObject *z; + Py_ssize_t size_a = ABS(Py_SIZE(a)); + Py_ssize_t size_b = ABS(Py_SIZE(b)); + Py_ssize_t i; + + z = _PyLong_New(size_a + size_b); + if (z == NULL) + return NULL; + + memset(z->ob_digit, 0, Py_SIZE(z) * sizeof(digit)); + if (a == b) { + /* Efficient squaring per HAC, Algorithm 14.16: + * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf + * Gives slightly less than a 2x speedup when a == b, + * via exploiting that each entry in the multiplication + * pyramid appears twice (except for the size_a squares). + */ + for (i = 0; i < size_a; ++i) { + twodigits carry; + twodigits f = a->ob_digit[i]; + digit *pz = z->ob_digit + (i << 1); + digit *pa = a->ob_digit + i + 1; + digit *paend = a->ob_digit + size_a; + + SIGCHECK({ + Py_DECREF(z); + return NULL; + }) + + carry = *pz + f * f; + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + assert(carry <= PyLong_MASK); + /* Now f is added in twice in each column of the + * pyramid it appears. Same as adding f<<1 once. 
+ */ + f <<= 1; + while (pa < paend) { + carry += *pz + *pa++ * f; + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + assert(carry <= (PyLong_MASK << 1)); + } + if (carry) { + carry += *pz; + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + } + if (carry) + *pz += (digit)(carry & PyLong_MASK); + assert((carry >> PyLong_SHIFT) == 0); + } + } + else { /* a is not the same as b -- gradeschool long mult */ + /* To reduce the number of shift and mask operations + * collect terms in the following way + * (a[0] + a[1]*BASE + a[2]*BASE**2 + ...) * + * (b[0] + b[1]*BASE + b[2]*BASE**2 + ...) = + * a[0]*b[0] + (a[0]*b[1] + a[1]*b[0])*BASE + + * (a[0]*b[2] + a[1]*b[1])*BASE**2 + ... + */ + for (i = 0; i < size_a - 1; i += 2) { + twodigits f0 = a->ob_digit[i]; + twodigits f1 = a->ob_digit[i+1]; + digit *pz = z->ob_digit + i; + digit *pb = b->ob_digit; + twodigits carry = *pz + pb[0] * f0; + *pz++ = (digit)(carry & PyLong_MASK); + /* carry <= MASK*(MASK+1) = BASE**2 - BASE */ + carry >>= PyLong_SHIFT; + /* carry <= BASE - 2 = MASK - 1 */ + SIGCHECK({ + Py_DECREF(z); + return NULL; + }) + int j; + /* Bounds on carry in the loop to guarantee + * that it does not overflow: + * For j = 0: + * carry += *pz + pb[j+1] * f0 + pb[j] * f1; + * carry <= MASK-1 + MASK + 2*MASK**2 + * = 2*MASK**2 + 2*MASK - 1 + * = 2*BASE**2 - 2*BASE - 1 + * carry >>= PyLong_SHIFT; + * carry <= 2*BASE - 3 = 2*MASK - 1 + * For j = 1: + * carry += *pz + pb[j+1] * f0 + pb[j] * f1; + * carry <= 2*MASK-1 + MASK + 2*MASK**2 = + * = 2*MASK**2 + 3*MASK - 1 = + * = 2*BASE**2 - BASE - 2 + * carry >>= PyLong_SHIFT; + * carry <= 2*BASE - 2 = 2*MASK + * For j = 2: + * carry += *pz + pb[j+1] * f0 + pb[j] * f1; + * carry <= 2*MASK + MASK + 2*MASK**2 = + * = 2*MASK**2 + 3*MASK + * = 2*BASE**2 - BASE - 1 + * carry >>= PyLong_SHIFT; + * carry <= 2*BASE - 2 = 2*MASK; + * as in the case j=1, so that the bounds remain + * the same for the rest of the loop; therefore + * in this loop one has 
always + * carry <= 2*MASK**2 + 3*MASK + * which fits in a twodigits, see longintrepr.h + * + */ + for(j=0; j < size_b-1; j++) { + carry += *pz + pb[j+1] * f0 + pb[j] * f1; + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + assert(carry <= 2*PyLong_MASK); + } + carry += *pz + pb[size_b-1] * f1; + /* carry <= 2*MASK + MASK + MASK**2 = + * = MASK**2 + 3*MASK = BASE**2 + BASE - 2 + */ + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + /* carry <= BASE */ + if (carry) + *pz += (digit)(carry & PyLong_MASK); + /* according to the above bound one has + * (carry >> PyLong_SHIFT) <= 1 + * However it must be strictly + * (carry >> PyLong_SHIFT) == 0 + * otherwise z would exceed size_a + size_b, which + * is impossible. + */ + assert((carry >> PyLong_SHIFT) == 0); + } + if (size_a&1) { + twodigits carry = 0; + twodigits f = a->ob_digit[size_a - 1]; + digit *pz = z->ob_digit + size_a - 1; + digit *pb = b->ob_digit; + digit *pbend = b->ob_digit + size_b; + + SIGCHECK({ + Py_DECREF(z); + return NULL; + }) + + while (pb < pbend) { + carry += *pz + *pb++ * f; + *pz++ = (digit)(carry & PyLong_MASK); + carry >>= PyLong_SHIFT; + assert(carry <= PyLong_MASK); + } + if (carry) + *pz += (digit)(carry & PyLong_MASK); + assert((carry >> PyLong_SHIFT) == 0); + } + } + return long_normalize(z); +} + +#ifdef HAVE_INT64 +/* Grade school long multiplication, ignoring the signs. + * Returns the absolute value of the product, or NULL if error. + * Same algorithm as in x_mul(), but using doubled sizes. 
+ */ +static PyLongObject * +x_mul2(PyLongObject *a, PyLongObject *b) +{ PyLongObject *z; - Py_ssize_t size_a = ABS(Py_SIZE(a)); - Py_ssize_t size_b = ABS(Py_SIZE(b)); + int size_a = ABS(Py_SIZE(a)); + int size_b = ABS(Py_SIZE(b)); Py_ssize_t i; - z = _PyLong_New(size_a + size_b); - if (z == NULL) - return NULL; - memset(z->ob_digit, 0, Py_SIZE(z) * sizeof(digit)); if (a == b) { - /* Efficient squaring per HAC, Algorithm 14.16: - * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf - * Gives slightly less than a 2x speedup when a == b, - * via exploiting that each entry in the multiplication - * pyramid appears twice (except for the size_a squares). - */ - for (i = 0; i < size_a; ++i) { - twodigits carry; - twodigits f = a->ob_digit[i]; - digit *pz = z->ob_digit + (i << 1); - digit *pa = a->ob_digit + i + 1; - digit *paend = a->ob_digit + size_a; + twodigits *a2 = to_twodigits(a); + if (a2 == NULL) + return NULL; + int size2_a = (size_a&1)? (size_a>>1) + 1: (size_a>>1); + z = _PyLong_New(size2_a << 2); + if (z == NULL) { + free(a2); + return NULL; + } + memset(z->ob_digit, 0, Py_SIZE(z) * sizeof(digit)); + twodigits *c2 = (twodigits*) calloc(size2_a + size2_a, + sizeof(twodigits)); + if (c2 == NULL) { + Py_DECREF(z); + free(a2); + return (PyLongObject *)PyErr_NoMemory(); + } + for (i = 0; i < size2_a; ++i) { + long_digits carry; + long_digits f = a2[i]; + twodigits *pz = c2 + (i << 1); + twodigits *pa = a2 + i + 1; + twodigits *paend = a2 + size2_a; SIGCHECK({ Py_DECREF(z); + free(a2); + free(c2); return NULL; }) carry = *pz + f * f; - *pz++ = (digit)(carry & PyLong_MASK); - carry >>= PyLong_SHIFT; - assert(carry <= PyLong_MASK); + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; + assert(carry <= PyLong_MASK2); /* Now f is added in twice in each column of the * pyramid it appears. Same as adding f<<1 once. 
@@ -2480,46 +2709,118 @@ f <<= 1; while (pa < paend) { carry += *pz + *pa++ * f; - *pz++ = (digit)(carry & PyLong_MASK); - carry >>= PyLong_SHIFT; - assert(carry <= (PyLong_MASK << 1)); + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; + assert(carry <= (PyLong_MASK2 << 1)); } if (carry) { carry += *pz; - *pz++ = (digit)(carry & PyLong_MASK); - carry >>= PyLong_SHIFT; + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; } if (carry) - *pz += (digit)(carry & PyLong_MASK); - assert((carry >> PyLong_SHIFT) == 0); + *pz += (twodigits)(carry & PyLong_MASK2); + assert((carry >> PyLong_SHIFT2) == 0); } + from_twodigits(z, c2, size2_a + size2_a); + free(a2); + free(c2); } else { /* a is not the same as b -- gradeschool long mult */ - for (i = 0; i < size_a; ++i) { - twodigits carry = 0; - twodigits f = a->ob_digit[i]; - digit *pz = z->ob_digit + i; - digit *pb = b->ob_digit; - digit *pbend = b->ob_digit + size_b; + /* create arrays of twodigits a2 and b2 from a and b; + * then proceed as in x_mul, with all types doubled: + * digit becomes twodigits, twodigits becomes + * long_digits, PyLong_SHIFT becomes PyLong_SHIFT2, + * PyLong_MASK becomes PyLong_MASK2. + * At the end convert c2 into z. + */ + twodigits *a2 = to_twodigits(a); + if (a2 == NULL) + return NULL; + int size2_a = (size_a&1)? (size_a>>1) + 1: (size_a>>1); + twodigits *b2 = to_twodigits(b); + if (b2 == NULL) { + free(a2); + return NULL; + } + int size2_b = (size_b&1)? 
(size_b>>1) + 1: (size_b>>1); + z = _PyLong_New((size2_a + size2_b)<<1); + if (z == NULL) { + free(a2); + free(b2); + return NULL; + } + memset(z->ob_digit, 0, Py_SIZE(z) * sizeof(digit)); + twodigits *c2 = (twodigits*) calloc(size2_a + size2_b, + sizeof(twodigits)); + if (c2 == NULL) { + Py_DECREF(z); + free(a2); + free(b2); + return (PyLongObject *)PyErr_NoMemory(); + } + for (i = 0; i < size2_a - 1; i += 2) { + long_digits f0 = a2[i]; + long_digits f1 = a2[i+1]; + twodigits *pz = c2 + i; + twodigits *pb = b2; + long_digits carry = *pz + pb[0] * f0; + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; SIGCHECK({ Py_DECREF(z); + free(a2); + free(b2); + free(c2); return NULL; }) + int j; + for(j=0; j < size2_b-1; j++) { + carry += *pz + pb[j+1] * f0 + pb[j] * f1; + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; + } + carry += *pz + pb[size2_b-1] * f1; + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; + if (carry) + *pz = (twodigits)(carry & PyLong_MASK2); + assert((carry >> PyLong_SHIFT2) == 0); + } + if (size2_a&1) { + long_digits carry = 0; + long_digits f = a2[size2_a - 1]; + twodigits *pz = c2 + size2_a - 1; + twodigits *pb = b2; + twodigits *pbend = b2 + size2_b; + SIGCHECK({ + Py_DECREF(z); + free(a2); + free(b2); + free(c2); + return NULL; + }) while (pb < pbend) { carry += *pz + *pb++ * f; - *pz++ = (digit)(carry & PyLong_MASK); - carry >>= PyLong_SHIFT; - assert(carry <= PyLong_MASK); + *pz++ = (twodigits)(carry & PyLong_MASK2); + carry >>= PyLong_SHIFT2; } if (carry) - *pz += (digit)(carry & PyLong_MASK); - assert((carry >> PyLong_SHIFT) == 0); + *pz = (twodigits)(carry & PyLong_MASK2); + assert((carry >> PyLong_SHIFT2) == 0); } + from_twodigits(z, c2, size2_a + size2_b); + free(a2); + free(b2); + free(c2); + } + return long_normalize(z); } +#endif /* A helper for Karatsuba multiplication (k_mul). 
Takes a long "n" and an integer "size" representing the place to @@ -2564,23 +2865,11 @@ { Py_ssize_t asize = ABS(Py_SIZE(a)); Py_ssize_t bsize = ABS(Py_SIZE(b)); - PyLongObject *ah = NULL; - PyLongObject *al = NULL; - PyLongObject *bh = NULL; - PyLongObject *bl = NULL; PyLongObject *ret = NULL; PyLongObject *t1, *t2, *t3; Py_ssize_t shift; /* the number of digits we split off */ Py_ssize_t i; - /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl - * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl - * Then the original product is - * ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl - * By picking X to be a power of 2, "*X" is just shifting, and it's - * been reduced to 3 multiplies on numbers half the size. - */ - /* We want to split based on the larger number; fiddle so that b * is largest. */ @@ -2595,14 +2884,35 @@ } /* Use gradeschool math when either number is too small. */ - i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; - if (asize <= i) { +#ifdef HAVE_INT64 + if (asize < MUL2) { if (asize == 0) return (PyLongObject *)PyLong_FromLong(0); else return x_mul(a, b); } + i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; + if (asize <= i) { + return x_mul2(a, b); + } +#else + i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; + if (asize <= i) { + if (asize == 0) + return (PyLongObject *)PyLong_FromLong(0); + else + return x_mul(a, b); + } +#endif + /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl + * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl + * Then the original product is + * ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl + * By picking X to be a power of 2, "*X" is just shifting, and it's + * been reduced to 3 multiplies on numbers half the size. + */ + /* If a is small compared to b, splitting on b gives a degenerate * case with ah==0, and Karatsuba may be (even much) less efficient * than "grade school" then. 
However, we can still win, by viewing @@ -2612,6 +2922,10 @@ if (2 * asize <= bsize) return k_lopsided_mul(a, b); + PyLongObject *ah = NULL; + PyLongObject *al = NULL; + PyLongObject *bh = NULL; + PyLongObject *bl = NULL; /* Split a & b into hi & lo pieces. */ shift = bsize >> 1; if (kmul_split(a, shift, &ah, &al) < 0) goto fail; @@ -2840,14 +3154,14 @@ PyLongObject *z; CHECK_BINOP(a, b); - + if (ABS(Py_SIZE(a)) <= 1 && ABS(Py_SIZE(b)) <= 1) { PyObject *r; r = PyLong_FromLong(MEDIUM_VALUE(a)*MEDIUM_VALUE(b)); return r; } + z = k_mul(a, b); - z = k_mul(a, b); /* Negate if exactly one of the inputs is negative. */ if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) NEGATE(z); @@ -3842,3 +4156,4 @@ } #endif } +