Author: juraj.sukop
Recipients: juraj.sukop, mark.dickinson, rhettinger, serhiy.storchaka, stutzbach, tim.peters
Date: 2021-01-30 15:13:02
Message-id: <1612019583.0.0.136424744258.issue43053@roundup.psfhosted.org>
Mark, thank you very much for the thorough reply!

I believe it is not for me to say whether the complication is worth it or not. As far as really huge numbers go, if I'm not mistaken, the original `isqrt` is already "fast" and only a constant-factor speed-up is possible.
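The constant-factor claim can be checked with a minimal timing sketch like the one below. It assumes `isqrt_2` has the same form as `isqrt_3` further down, minus the `//64*64` rounding. The helper name `speedup` and the chosen sizes are illustrative only, and no timings are claimed here.

```python
import random
import timeit
from math import isqrt


def isqrt_2(n):
    # Assumed form of isqrt_2: one Newton step starting from the root of the
    # top half of n. Requires n > 0 (for n == 0, l is negative).
    l = (n.bit_length() - 1)//4
    a = isqrt(n >> 2*l)
    a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
    return a - (a*a > n)


def speedup(bits, number=20):
    """Return time(math.isqrt) / time(isqrt_2) for one random `bits`-bit input."""
    n = random.getrandbits(bits) | (1 << (bits - 1))  # force exactly `bits` bits
    assert isqrt_2(n) == isqrt(n)
    t_ref = timeit.timeit(lambda: isqrt(n), number=number)
    t_new = timeit.timeit(lambda: isqrt_2(n), number=number)
    return t_ref / t_new


if __name__ == "__main__":
    for bits in (1_000, 10_000, 100_000):
        print(f"{bits:>7} bits: math.isqrt / isqrt_2 = {speedup(bits):.2f}")
```

A ratio that stays roughly flat as `bits` grows would be consistent with a constant-factor gain rather than a better asymptotic cost.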

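I was also trying to speed up the final correction, i.e. the full-width multiplication in `a*a > n`, but didn't get very far (only from M(n) to M(n/2), where M(n) is the cost of an n-bit multiplication). Continuing from `isqrt_2`, each step below rewrites the previous one: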

a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
return a - (a*a > n)

# fold the final //2 into the shift: (m >> l)//a//2 == (m >> (l + 1))//a
a = (a << l) + ((n - (a**2 << 2*l)) >> (l + 1))//a
return a - (a**2 > n)

x = a
a = (x << l) + ((n - (x**2 << 2*l)) >> (l + 1))//x
return a - (((x << l) + ((n - (x**2 << 2*l)) >> (l + 1))//x)**2 > n)

x = a
x2 = x**2
a = (x << l) + ((n - (x2 << 2*l)) >> (l + 1))//x
return a - (((x << l) + ((n - (x2 << 2*l)) >> (l + 1))//x)**2 > n)

x = a
x2 = x**2
t = (n - (x2 << 2*l)) >> (l + 1)
u = t//x
a = (x << l) + u
b = ((x << l) + u)**2
return a - (b > n)

x = a
x2 = x**2
t = (n - (x2 << 2*l)) >> (l + 1)
u = t//x
a = (x << l) + u
b = ((x << l)**2 + 2*(x << l)*u + u**2)  # ((x << l) + u)**2, expanded
return a - (b > n)

x = a
x2 = x**2
t = (n - (x2 << 2*l)) >> (l + 1)
u, v = divmod(t, x)  # t == u*x + v, so x*u == t - v
a = (x << l) + u
b = ((x2 << 2*l) + ((t - v) << (l + 1)) + u**2)  # 2*(x << l)*u == (t - v) << (l + 1)
return a - (b > n)
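The algebra behind the last two rewrites is written out below. Here $X = x\,2^{l}$ stands for `x << l`, and $t = ux + v$ comes from the `divmod`:

```latex
\begin{aligned}
a   &= X + u, \qquad X = x\,2^{l}, \qquad t = u x + v,\quad 0 \le v < x \\
a^2 &= X^2 + 2Xu + u^2 \\
    &= x^2\,2^{2l} + 2^{l+1}\,(x u) + u^2 \\
    &= x^2\,2^{2l} + 2^{l+1}\,(t - v) + u^2
\end{aligned}
```

So `b` needs only `x2`, a shift of `t - v`, and `u**2`. No full-width product remains.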

But then I got stuck on that `u**2`. The above should be a bit faster, but it is probably not worth the complication. On the other hand, if it were worth it, each iteration would need only a single half-width multiplication, since the `x2` of one iteration is the `b` of the previous one.
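The reuse idea can be sketched as a recursive variant that returns the root together with its square. Each level then gets `x2` from the level below instead of computing `x**2`. This is a minimal sketch under stated assumptions, not a proposal for `math.isqrt`:

- The name `isqrt_with_square` is hypothetical.
- The 64-bit base-case cutoff is arbitrary.
- The `b - 2*a + 1` adjustment is an addition. It is needed because `b` is the square of `a` before the final correction.

Per level, the only multiplication left is `u*u`, next to one `divmod`.

```python
from math import isqrt


def isqrt_with_square(n):
    """Return (a, a*a) with a == isqrt(n), reusing the square across levels.

    Hypothetical helper sketching the x2-reuse idea; requires n >= 0.
    """
    if n < 1 << 64:
        # Base case (cutoff chosen arbitrarily): small enough to do directly.
        a = isqrt(n)
        return a, a * a
    l = (n.bit_length() - 1)//4
    x, x2 = isqrt_with_square(n >> 2*l)   # x2 == x**2 comes from the recursion
    t = (n - (x2 << 2*l)) >> (l + 1)
    u, v = divmod(t, x)                   # t == u*x + v
    a = (x << l) + u
    b = (x2 << 2*l) + ((t - v) << (l + 1)) + u*u   # b == a**2
    if b > n:
        # a overshot by one: (a - 1)**2 == b - 2*a + 1
        return a - 1, b - (a << 1) + 1
    return a, b
```

`isqrt_with_square(n)[0]` should agree with `math.isqrt(n)`. Whether skipping the squaring pays for the extra bookkeeping is not established here.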

Another idea, maybe worth trying (only in C), is to choose `l` during the iterations so that the numbers land on word boundaries. Computing a few more bits here and there may not hurt performance much, but it would replace the shifts with pointer arithmetic. Here the only change is `//64*64`:

from math import isqrt

def isqrt_3(n):
    # assumes n > 0 (for n == 0, l is negative and the shift raises ValueError)
    l = (n.bit_length() - 1)//4//64*64
    a = isqrt(n >> 2*l)
    a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
    return a - (a*a > n)
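A pure-Python check that rounding `l` down keeps the result exact can be written with the alignment as a parameter. This sketch rests on two assumptions:

- The name `isqrt_aligned` and its `word_bits` parameter are hypothetical.
- The default uses `sys.int_info.bits_per_digit` (typically 30), because CPython stores ints in digits of that size rather than in 64-bit words.

Rounding `l` down only gives the inner `isqrt` more bits to compute. In Python this says nothing about speed. Any payoff would come in C, where a shift by a whole number of digits becomes an offset into the digit array.

```python
import sys
from math import isqrt


def isqrt_aligned(n, word_bits=sys.int_info.bits_per_digit):
    """isqrt_3 with the alignment made a parameter (hypothetical helper).

    l is rounded down to a multiple of word_bits, so n >> 2*l and a << l
    shift by whole digits when word_bits matches the digit size.
    Requires n > 0.
    """
    l = (n.bit_length() - 1)//4//word_bits*word_bits
    a = isqrt(n >> 2*l)
    a = (a << l) + ((n - (a**2 << 2*l)) >> l)//a//2
    return a - (a*a > n)
```

For example, `isqrt_aligned(n, 64)` is the same computation as `isqrt_3(n)`, and `isqrt_aligned(n, 1)` is `isqrt_2(n)`.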