classification
Title: Speed up math.isqrt, again
Type: performance Stage: resolved
Components: Extension Modules Versions: Python 3.8
process
Status: closed Resolution: rejected
Dependencies: Superseder:
Assigned To: mark.dickinson Nosy List: juraj.sukop, mark.dickinson, rhettinger, serhiy.storchaka, stutzbach, tim.peters
Priority: normal Keywords: patch

Created on 2021-01-28 08:03 by juraj.sukop, last changed 2021-02-02 13:38 by mark.dickinson. This issue is now closed.

Files
File name Uploaded Description
isqrt_2.py juraj.sukop, 2021-01-28 08:03
Pull Requests
URL Status Linked
PR 24414 closed mark.dickinson, 2021-02-01 19:04
Messages (11)
msg385848 - Author: Juraj Sukop (juraj.sukop) Date: 2021-01-28 08:03
This is a follow up to https://bugs.python.org/issue36887 and https://bugs.python.org/issue36957 .

The new `isqrt` is remarkably simple, but it does not split the number at hand optimally. Ideally one would want a 2n/n division everywhere, but since the last iteration takes as much effort as all of the iterations before it, the attached code focuses on that last iteration.

At least in my testing, the attached `isqrt_2` code improved performance by 50% (from 3 s down to 2 s, for example), and, if it were used, the original `isqrt` could perhaps do without its final correction `a - (a*a > n)`.
msg385849 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-01-28 08:41
Thanks; I'll take a look at this at the weekend. Do you have a sketch of a proof of correctness available?
msg385856 - Author: Juraj Sukop (juraj.sukop) Date: 2021-01-28 13:13
As far as the proof goes, you've done most of the work already. Consider the following:

    l = (n.bit_length() - 1)//4
    a = isqrt(n >> 2*l)
    a = ((a << l) + n//(a << l))//2
    return a - (a*a > n)

This computes the square root of the (possibly longer) upper half, applies one Heron step, and then a single correction. I think it is functionally equivalent to what you wrote. The `l` trailing zero bits of `a << l` don't contribute to the quotient, so we could instead write:

    a = ((a << l) + (n >> l)//a)//2

The problem is that the 3n/n division in the step `(a + n//a)//2` basically recomputes the upper half we already know, so we want to avoid it: instead of a 3n/n division giving a 2n quotient, we want a 2n/n division giving a 1n quotient. If the upper half is correct, what the lower half still has to account for is the residual `n - (a << l)**2`:

    a = (a << l) + ((n - (a << l)**2) >> l)//a//2

And there is no need to square the zeros either:

    a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2

So I *think* it should be correct; the only thing I'm not sure about is whether the final correction in the original `isqrt` is needed. Perhaps your automated proof could give an answer?
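The successive forms of the Heron step in this message can be cross-checked numerically. This is a minimal sketch, not the attached `isqrt_2.py` (whose contents aren't reproduced here): it assumes `math.isqrt` (Python 3.8+) computes the upper-half root, and the names `heron_step_forms` and `isqrt_2_sketch` are hypothetical.

```python
import random
from math import isqrt  # assumption: stands in for the inner isqrt on the upper half


def heron_step_forms(n):
    """Return the four rewritten forms of the single Heron step (n >= 1)."""
    l = (n.bit_length() - 1) // 4
    a = isqrt(n >> 2*l)                                   # root of the upper half
    f1 = ((a << l) + n // (a << l)) // 2                  # 3n/n division
    f2 = ((a << l) + (n >> l) // a) // 2                  # trailing zeros dropped
    f3 = (a << l) + ((n - (a << l)**2) >> l) // a // 2    # divide only the residual
    f4 = (a << l) + ((n - (a**2 << 2*l)) >> l) // a // 2  # don't square the zeros
    return f1, f2, f3, f4


def isqrt_2_sketch(n):
    """Floor square root of n >= 1 using the last form plus one correction."""
    l = (n.bit_length() - 1) // 4
    a = isqrt(n >> 2*l)
    a = (a << l) + ((n - (a**2 << 2*l)) >> l) // a // 2
    return a - (a*a > n)


samples = list(range(1, 5000)) + [random.getrandbits(k) + 1 for k in (64, 256, 4096)]
for n in samples:
    assert len(set(heron_step_forms(n))) == 1  # all four forms agree
    assert isqrt_2_sketch(n) == isqrt(n)
```

This checks correctness only; it says nothing about speed, which depends on the division sizes discussed in the thread.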
msg385861 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-01-28 13:57
> the only thing I'm not sure about is whether the final correction in the original `isqrt` is needed

Well, *some* part of the algorithm has to make use of the low-order bits of n. Otherwise we won't be able to distinguish n = 4a**2 + 4a + 1 (whose isqrt is 2a + 1) from 4a**2 + 4a (whose isqrt is 2a).
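The example is easy to confirm: 4a**2 + 4a is even, so it and (2a + 1)**2 = 4a**2 + 4a + 1 differ only in bit 0, yet their integer square roots differ. A small sketch, using `math.isqrt` as the reference:

```python
from math import isqrt

for a in range(1, 1000):
    n = 4*a*a + 4*a                # even, so n and n + 1 differ only in bit 0
    assert n ^ (n + 1) == 1
    assert isqrt(n) == 2*a         # (2a)**2 <= n < (2a + 1)**2
    assert isqrt(n + 1) == 2*a + 1

print(bin(120), isqrt(120))  # a = 5: 0b1111000 10
print(bin(121), isqrt(121))  # 0b1111001 11
```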
msg385862 - Author: Juraj Sukop (juraj.sukop) Date: 2021-01-28 14:08
Because `l` is rounded down, the inner `isqrt` may compute more than half of the bits, so the final Heron step in `isqrt_2` might correct the uncertain low bit even if `a - (a*a > n)` were missing from `isqrt`.

As it currently stands, `a - (a*a > n)` is computed in both `isqrt` and `isqrt_2`, so I was thinking that maybe the former could be dropped.

Are you saying that both corrections are needed?
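The question of dropping the inner correction can be probed empirically. This sketch (hypothetical names `nsqrt_iter` and `isqrt_2_single_correction`) replaces the inner `isqrt` with the loop from mathmodule.c minus its final correction, so the inner result may be the ceiling rather than the floor, and keeps only the outer correction. A brute-force comparison with `math.isqrt` is evidence, not a proof. Note that the residual `n - (a**2 << 2*l)` can now be negative, so the step relies on Python's floor semantics for `>>` and `//`.

```python
import random
from math import isqrt  # reference only


def nsqrt_iter(n):
    """Near square root of n >= 1: the mathmodule.c loop without its final correction.

    The result a satisfies (a - 1)**2 < n < (a + 1)**2, so it may be the ceiling.
    """
    c = (n.bit_length() - 1) // 2
    a = 1
    d = 0
    for s in reversed(range(c.bit_length())):
        e = d
        d = c >> s
        a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
    return a


def isqrt_2_single_correction(n):
    """Floor square root of n >= 1 with only the outer correction (sketch)."""
    l = (n.bit_length() - 1) // 4
    a = nsqrt_iter(n >> 2*l)  # floor or ceiling of the upper-half root
    # The residual may be negative here; Python's >> and // round toward -inf.
    a = (a << l) + ((n - (a**2 << 2*l)) >> l) // a // 2
    return a - (a*a > n)


# The inner result really is sometimes the ceiling (e.g. nsqrt_iter(8) == 3) ...
assert any(nsqrt_iter(m)**2 > m for m in range(1, 100))
# ... yet the single outer correction still gives the floor square root.
samples = list(range(1, 20_000)) + [random.getrandbits(k) + 1 for k in (64, 256, 4096)]
assert all(isqrt_2_single_correction(n) == isqrt(n) for n in samples)
```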
msg385863 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-01-28 14:10
> As it currently stands, `a - (a*a > n)` is computed both in `isqrt` and `isqrt_2`. So I was thinking that maybe the former might be dropped.

Ah, sorry; I misunderstood. Yes, I think so. I'll respond more fully later. (Sorry - real life getting in the way right now.)
msg385872 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-01-28 17:30
Some comments, now that I've had a chance to look properly at the suggestion.

For reference, here's the "near square root" function that forms the basis of Python's isqrt algorithm. For clarity, I've written it recursively, but it's equivalent to the iterative version described in mathmodule.c. (Definition: for a positive integer n, a "near square root" of n is an integer a such that |a - √n| < 1; or in other words the near square roots of n are the floor and the ceiling of √n.)

    def nsqrt(n):
        """Compute a near square root for a positive integer n."""
        if n < 4:
            return 1
        else:
            e = (n.bit_length() - 3) // 4
            a = nsqrt(n >> 2*e + 2)
            return (a << e) + (n >> e + 2) // a
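Turning a near square root into the floor square root needs just one comparison, which is what the final `a - (a*a > n)` in mathmodule.c does. A minimal sketch that repeats the recursive `nsqrt` so it runs standalone; `isqrt_via_nsqrt` is a hypothetical name, and `math.isqrt` is used only as a reference:

```python
from math import isqrt  # reference only


def nsqrt(n):
    """Compute a near square root for a positive integer n."""
    if n < 4:
        return 1
    e = (n.bit_length() - 3) // 4
    a = nsqrt(n >> 2*e + 2)
    return (a << e) + (n >> e + 2) // a


def isqrt_via_nsqrt(n):
    """Floor square root of a nonnegative integer n (hypothetical helper)."""
    if n == 0:
        return 0
    a = nsqrt(n)          # floor or ceiling of sqrt(n)
    return a - (a*a > n)  # step down only in the ceiling case


for n in range(1, 10_000):
    a = nsqrt(n)
    assert (a - 1)**2 < n < (a + 1)**2  # the near-square-root property
    assert isqrt_via_nsqrt(n) == isqrt(n)
```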

Juraj's suggestion, applied to each step of the recursion rather than just the outer step, amounts to computing the expression in the last line in a different way. (What follows isn't *identical* to what Juraj is suggesting in all the details, but it's essentially equivalent and has the same key performance implications.) Here's the proposed new version, identical to the previous one except for the last line:

    def nsqrt(n):
        """Compute a near square root for a positive integer n."""
        if n < 4:
            return 1
        else:
            e = (n.bit_length() - 3) // 4
            a = nsqrt(n >> 2*e + 2)
            return (a << e + 1) + ((n >> e + 2) - (a * a << e)) // a

With regard to proof, it's straightforward to see that this is equivalent: we have

        (a << e) + (n >> e + 2) // a
     == (a << e) + (a << e) + (n >> e + 2) // a - (a << e)
     == (a << e + 1) + (n >> e + 2) // a - (a << e)
     == (a << e + 1) + ((n >> e + 2) - (a * a << e)) // a

It's interesting to note that with this version we *do* rely on Python's floor division semantics for negative numbers, since the quantity ((n >> e + 2) - (a * a << e)) could be negative; that last equality is not valid with C-style sign-of-result-is-sign-of-dividend division. The first version works entirely with nonnegative integers. (And the version proposed by Juraj also works entirely with nonnegative integers, as a result of the extra correction step.)
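The dependence on floor division in that last equality can be seen directly. A small sketch; `trunc_div` is a hypothetical helper that emulates C's round-toward-zero division for a positive divisor.

```python
def trunc_div(x, y):
    """C-style integer division (rounds toward zero), for y > 0. Hypothetical helper."""
    q = abs(x) // y
    return q if x >= 0 else -q


# With floor division the rewrite  N//a - (a << e) == (N - (a*a << e))//a
# holds for all integers N and positive a, including when N - (a*a << e) < 0.
for N in range(-50, 50):
    for a in range(1, 10):
        for e in range(4):
            assert N // a - (a << e) == (N - (a * a << e)) // a

# With truncating division it can fail once the dividend goes negative.
N, a, e = 5, 3, 0
assert trunc_div(N, a) - (a << e) == -2      # 1 - 3
assert trunc_div(N - (a * a << e), a) == -1  # trunc(-4 / 3), not floor(-4 / 3) == -2
```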

And now the key point is that the dividend (n >> e + 2) - (a * a << e) in the second version is significantly smaller (roughly two-thirds the bit count) than the original dividend n >> e + 2. That should make the division roughly twice as fast (in the limit as n gets large), and since the division is the main bottleneck in the algorithm for large n, it should speed up the algorithm overall, again in the limit for large n, despite the fact that we have a greater number of arithmetic operations per iteration. And with Python's current subquadratic multiplication but quadratic division, the advantage becomes more significant as n gets large.
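The size claim can be spot-checked on the outermost step. This sketch assumes `math.isqrt` supplies the inner near square root (a floor is one valid choice) and compares the bit lengths of the two dividends for a large n; `step_dividend_bits` is a hypothetical name.

```python
from math import isqrt  # assumption: supplies the inner near square root


def step_dividend_bits(n):
    """Bit lengths (old dividend, new dividend, divisor) for the outermost step.

    Hypothetical helper; n should be large enough that e > 0.
    """
    e = (n.bit_length() - 3) // 4
    a = isqrt(n >> 2*e + 2)            # a floor is one valid near square root
    old = n >> e + 2                   # dividend in the first version
    new = (n >> e + 2) - (a * a << e)  # dividend in the second version
    # Both versions produce the same value for this step.
    assert (a << e) + old // a == (a << e + 1) + new // a
    # int.bit_length() ignores the sign, so a negative `new` (possible with a
    # ceiling near root) would be measured by magnitude.
    return old.bit_length(), new.bit_length(), a.bit_length()


old_bits, new_bits, divisor_bits = step_dividend_bits(2 * 10**1000)
print(old_bits, new_bits, divisor_bits)
```

For an n of this size the new dividend should come out at no more than about two-thirds of the old one's bit length, while the divisor `a` is about a third of it.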

The second version is a little more complicated than the first, but the complication probably amounts to no more than 10-20 extra lines of C code. Still, there's a maintenance cost in adding that complication to be considered.

But here's the thing: I don't actually care about the performance for *huge* n - people who care about that sort of thing would be better off using gmpy2. I'm much more interested in the performance implications for *smaller* n: that is, integers of length 64-256 bits, say. (For n smaller than 2**64 the difference is irrelevant, since we have a pure C fast path there.) If the second version performs better across the board it may be worth the extra complexity. If it doesn't, then what's the cutoff? That is, where does the second version start outperforming the first? I'm not really so interested in having a hybrid algorithm that switches from one solution to the other at some threshold - that's a step too far complexity-wise.

So I propose that the next step be to code up the second variant in mathmodule.c and do some performance testing.

Juraj: are you interested in working on a PR?
msg385873 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-01-28 17:57
Translation of the proposal to the iterative version described here: https://github.com/python/cpython/blob/64fc105b2d2faaeadd1026d2417b83915af6622f/Modules/mathmodule.c#L1591-L1611

The main loop:

        c = (n.bit_length() - 1) // 2
        a = 1
        d = 0
        for s in reversed(range(c.bit_length())):
            # Loop invariant: (a-1)**2 < (n >> 2*(c - d)) < (a+1)**2
            e = d
            d = c >> s
            a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a

becomes (again identical except for the last line):

        c = (n.bit_length() - 1) // 2
        a = 1
        d = 0
        for s in reversed(range(c.bit_length())):
            # Loop invariant: (a-1)**2 < (n >> 2*(c - d)) < (a+1)**2
            e = d
            d = c >> s
            a = (a << d - e) + ((n >> 2*c - e - d + 1) - (a*a << d - e - 1)) // a
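Both loops can be wrapped into complete functions and compared. In this sketch the names `isqrt_current` and `isqrt_proposed` are hypothetical, both assume n >= 1 (CPython handles n == 0 separately), and the final `a - (a*a > n)` follows the mathmodule.c description linked in this message.

```python
import random
from math import isqrt  # reference only


def isqrt_current(n):
    """Floor square root of n >= 1 via the loop described in mathmodule.c."""
    c = (n.bit_length() - 1) // 2
    a = 1
    d = 0
    for s in reversed(range(c.bit_length())):
        # Loop invariant: (a-1)**2 < (n >> 2*(c - d)) < (a+1)**2
        e = d
        d = c >> s
        a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
    return a - (a*a > n)


def isqrt_proposed(n):
    """Same loop, dividing the smaller (possibly negative) residual instead."""
    c = (n.bit_length() - 1) // 2
    a = 1
    d = 0
    for s in reversed(range(c.bit_length())):
        e = d
        d = c >> s
        a = (a << d - e) + ((n >> 2*c - e - d + 1) - (a*a << d - e - 1)) // a
    return a - (a*a > n)


samples = list(range(1, 20_000)) + [random.getrandbits(k) + 1 for k in (64, 128, 256, 4096)]
assert all(isqrt_current(n) == isqrt_proposed(n) == isqrt(n) for n in samples)
```

Pure-Python timings of these two functions would not reflect the C implementation, so the sketch only checks that the rewritten update is a drop-in replacement.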
msg385985 - Author: Juraj Sukop (juraj.sukop) Date: 2021-01-30 15:13
Mark, thank you very much for the thorough reply!

I believe it is not for me to say whether the complication is worth it or not. As for really huge numbers, if I'm not mistaken, the original `isqrt` is already "fast" and only a constant-factor speed-up is possible.

I was also trying to speed up the last correction, i.e. that full-width multiplication, but didn't get too far (only M(n) -> M(n/2)). Continuing from `isqrt_2`:

    a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
    return a - (a*a > n)
    
    a = (a << l) + ((n - (a**2 << 2*l)) >> (l + 1))//a
    return a - (a**2 > n)
    
    x = a
    a = (x << l) + ((n - (x**2 << 2*l)) >> (l + 1))//x
    return a - (((x << l) + ((n - (x**2 << 2*l)) >> (l + 1))//x)**2 > n)
    
    x = a
    x2 = x**2
    a = (x << l) + ((n - (x2 << 2*l)) >> (l + 1))//x
    return a - (((x << l) + ((n - (x2 << 2*l)) >> (l + 1))//x)**2 > n)
    
    x = a
    x2 = x**2
    t = (n - (x2 << 2*l)) >> (l + 1)
    u = t//x
    a = (x << l) + u
    b = ((x << l) + u)**2
    return a - (b > n)
    
    x = a
    x2 = x**2
    t = (n - (x2 << 2*l)) >> (l + 1)
    u = t//x
    a = (x << l) + u
    b = ((x << l)**2 + 2*(x << l)*u + u**2)
    return a - (b > n)
    
    x = a
    x2 = x**2
    t = (n - (x2 << 2*l)) >> (l + 1)
    u, v = divmod(t, x)
    a = (x << l) + u
    b = ((x2 << 2*l) + ((t - v) << (l + 1)) + u**2)
    return a - (b > n)

But then I got stuck on that `u**2`. The above should be a bit faster but is probably not worth the complication. On the other hand, if it were worth the complication, each iteration would need only a single half-width multiplication, since `x2` of one iteration is `b` of the previous iteration.
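Assembled into a function, the last variant can be checked against `math.isqrt`. This sketch assumes `math.isqrt` for the inner root (the thread's own pure-Python `isqrt` would do equally well) and `isqrt_2_divmod` is a hypothetical name; it checks correctness only. Besides `x**2`, which the message suggests could be carried over from the previous iteration's `b`, the only fresh square is `u**2`, and `u` is only about half as wide as `a`.

```python
import random
from math import isqrt  # assumption: stands in for the inner isqrt


def isqrt_2_divmod(n):
    """Floor square root of n >= 1: the last variant, made standalone.

    Hypothetical name. The square of the candidate is assembled from x2 and
    the divmod remainder instead of squaring the full-width candidate.
    """
    l = (n.bit_length() - 1) // 4
    x = isqrt(n >> 2*l)
    x2 = x**2
    t = (n - (x2 << 2*l)) >> (l + 1)
    u, v = divmod(t, x)                            # t == u*x + v
    a = (x << l) + u
    b = (x2 << 2*l) + ((t - v) << (l + 1)) + u**2  # == a*a, since x*u == t - v
    return a - (b > n)


samples = list(range(1, 20_000)) + [random.getrandbits(k) + 1 for k in (64, 256, 4096)]
for n in samples:
    assert isqrt_2_divmod(n) == isqrt(n)
```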

Another idea that may be worth trying (only in C) is to choose `l` during the iterations so that the numbers land on word boundaries. In other words, computing a few more bits here and there may not hurt performance much, but it would replace the shifts with pointer arithmetic. Here the only change is `//64*64`:

    def isqrt_3(n):
        l = (n.bit_length() - 1)//4//64*64
        a = isqrt(n >> 2*l)
        a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
        return a - (a*a > n)
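A runnable version of `isqrt_3`, assuming `math.isqrt` for the inner root. The word size is made a hypothetical `word` parameter because CPython's own int representation typically uses 30-bit digits rather than 64-bit words, so a C version working on CPython's digits would presumably round `l` to a multiple of the digit size instead. Rounding `l` down only makes the inner root cover more bits, so a single Heron step and one correction should still suffice; the check here is empirical.

```python
import random
from math import isqrt  # assumption: stands in for the inner isqrt


def isqrt_3(n, word=64):
    """Floor square root of n >= 1, with l rounded down to a multiple of `word`.

    `word` is a hypothetical parameter; the message hard-codes 64.
    """
    l = (n.bit_length() - 1) // 4 // word * word
    a = isqrt(n >> 2*l)
    a = (a << l) + ((n - (a**2 << 2*l)) >> l) // a // 2
    return a - (a*a > n)


samples = list(range(1, 5_000)) + [random.getrandbits(k) + 1 for k in (64, 300, 1000, 4096)]
for word in (30, 64):
    assert all(isqrt_3(n, word) == isqrt(n) for n in samples)
```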
msg386089 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-02-01 18:57
> the complication probably amounts to no more than 10-20 extra lines of C code

A net difference of 16 lines of code, as it turns out. The branch is here: https://github.com/mdickinson/cpython/tree/isqrt-performance

Informal not-very-scientific timings more-or-less confirm what I expected: I _do_ get a speedup approaching a factor of 2 for huge n: getting a million digits of sqrt(2) via `n = 2*10**10**6; x = isqrt(n)` takes around 9 seconds on master and 5 seconds with this branch, on my machine. But for values with 20 digits or so, the overhead of the extra operations means that the algorithm is around 20% slower. The cutoff for me seems to be somewhere between 200 and 1000 digits.

So I'm afraid I'm going to leave this as is: if speed were all we cared about then there are all sorts of things we could try, but I'd rather keep the simplicity. And it's nice that it's still *possible* to compute a million digits of sqrt(2) in a few seconds. Java's implementation of BigInteger.sqrt can't do that. :-)
msg386092 - Author: Mark Dickinson (mark.dickinson) * (Python committer) Date: 2021-02-01 19:30
> Java's implementation of BigInteger.sqrt can't do that. :-)

Well, okay, depending on your definition of "a few", actually it can, but Python is still faster.
History
Date User Action Args
2021-02-02 13:38:53 mark.dickinson set status: open -> closed; resolution: rejected; stage: patch review -> resolved
2021-02-01 19:30:01 mark.dickinson set messages: + msg386092
2021-02-01 19:04:52 mark.dickinson set keywords: + patch; stage: patch review; pull_requests: + pull_request23229
2021-02-01 18:57:25 mark.dickinson set messages: + msg386089
2021-01-30 15:13:02 juraj.sukop set messages: + msg385985
2021-01-28 17:57:58 mark.dickinson set messages: + msg385873
2021-01-28 17:30:05 mark.dickinson set nosy: + tim.peters; messages: + msg385872
2021-01-28 14:10:04 mark.dickinson set messages: + msg385863
2021-01-28 14:08:34 juraj.sukop set messages: + msg385862
2021-01-28 13:57:21 mark.dickinson set messages: + msg385861
2021-01-28 13:13:29 juraj.sukop set messages: + msg385856
2021-01-28 08:41:53 mark.dickinson set assignee: mark.dickinson; messages: + msg385849
2021-01-28 08:06:06 serhiy.storchaka set nosy: + rhettinger, mark.dickinson, stutzbach, serhiy.storchaka
2021-01-28 08:03:26 juraj.sukop create