Skip to content

Commit 5c08ce9

Browse files
authored
bpo-36957: Speed up math.isqrt (#13405)
* Add math.isqrt function computing the integer square root.
* Code cleanup: remove redundant comments, rename some variables.
* Tighten up code a bit more; use Py_XDECREF to simplify error handling.
* Update Modules/mathmodule.c
  Co-Authored-By: Serhiy Storchaka <[email protected]>
* Update Modules/mathmodule.c
  Use real argument clinic type instead of an alias
  Co-Authored-By: Serhiy Storchaka <[email protected]>
* Add proof sketch
* Updates from review.
* Correct and expand documentation.
* Fix bad reference handling on error; make some variables block-local; other tidying.
* Style and consistency fixes.
* Add missing error check; don't try to DECREF a NULL a
* Simplify some error returns.
* Another two test cases:
  - clarify that floats are rejected even if they happen to be squares of small integers
  - TypeError beats ValueError for a negative float
* Add fast path for small inputs. Needs tests.
* Speed up isqrt for n >= 2**64 as well; add extra tests.
* Reduce number of test-cases to avoid dominating the run-time of test_math.
* Don't perform unnecessary extra iterations when computing c_bit_length.
* Abstract common uint64_t code out into a separate function.
* Cleanup.
* Add a missing Py_DECREF in an error branch. More cleanup.
* Update Modules/mathmodule.c
  Add missing `static` declaration to helper function.
  Co-Authored-By: Serhiy Storchaka <[email protected]>
* Add missing backtick.
1 parent 7c59362 commit 5c08ce9

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

Lib/test/test_math.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ def testIsqrt(self):
917917
test_values = (
918918
list(range(1000))
919919
+ list(range(10**6 - 1000, 10**6 + 1000))
920+
+ [2**e + i for e in range(60, 200) for i in range(-40, 40)]
920921
+ [3**9999, 10**5001]
921922
)
922923

Modules/mathmodule.c

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,22 @@ completes the proof sketch.
16201620
16211621
*/
16221622

1623+
1624+
/* Approximate square root of a large 64-bit integer.
1625+
1626+
Given `n` satisfying `2**62 <= n < 2**64`, return `a`
1627+
satisfying `(a - 1)**2 < n < (a + 1)**2`. */
1628+
1629+
static uint64_t
1630+
_approximate_isqrt(uint64_t n)
1631+
{
1632+
uint32_t u = 1U + (n >> 62);  /* seed from the top 2 bits of n: u is 2, 3 or 4 when n >= 2**62 */
1633+
u = (u << 1) + (n >> 59) / u;  /* refine using the top 5 bits of n; result narrowed back to 32 bits */
1634+
u = (u << 3) + (n >> 53) / u;  /* refine using the top 11 bits of n */
1635+
u = (u << 7) + (n >> 41) / u;  /* refine using the top 23 bits of n */
1636+
return (u << 15) + (n >> 17) / u;  /* last step uses the top 47 bits; the quotient is uint64_t, so the sum is formed in 64 bits (callers handle a result of 2**32) */
1637+
}
1638+
16231639
/*[clinic input]
16241640
math.isqrt
16251641
@@ -1633,8 +1649,9 @@ static PyObject *
16331649
math_isqrt(PyObject *module, PyObject *n)
16341650
/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
16351651
{
1636-
int a_too_large, s;
1652+
int a_too_large, c_bit_length;
16371653
size_t c, d;
1654+
uint64_t m, u;
16381655
PyObject *a = NULL, *b;
16391656

16401657
n = PyNumber_Index(n);
@@ -1653,24 +1670,55 @@ math_isqrt(PyObject *module, PyObject *n)
16531670
return PyLong_FromLong(0);
16541671
}
16551672

1673+
/* c = (n.bit_length() - 1) // 2 */
16561674
c = _PyLong_NumBits(n);
16571675
if (c == (size_t)(-1)) {
16581676
goto error;
16591677
}
16601678
c = (c - 1U) / 2U;
16611679

1662-
/* s = c.bit_length() */
1663-
s = 0;
1664-
while ((c >> s) > 0) {
1665-
++s;
1680+
/* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
1681+
fast, almost branch-free algorithm. In the final correction, we use `u*u
1682+
- 1 >= m` instead of the simpler `u*u > m` in order to get the correct
1683+
result in the corner case where `u=2**32`. */
1684+
if (c <= 31U) {
1685+
m = (uint64_t)PyLong_AsUnsignedLongLong(n);
1686+
Py_DECREF(n);
1687+
if (m == (uint64_t)(-1) && PyErr_Occurred()) {
1688+
return NULL;
1689+
}
1690+
u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
1691+
u -= u * u - 1U >= m;
1692+
return PyLong_FromUnsignedLongLong((unsigned long long)u);
16661693
}
16671694

1668-
a = PyLong_FromLong(1);
1695+
/* Slow path: n >= 2**64. We perform the first five iterations in C integer
1696+
arithmetic, then switch to using Python long integers. */
1697+
1698+
/* From n >= 2**64 it follows that c.bit_length() >= 6. */
1699+
c_bit_length = 6;
1700+
while ((c >> c_bit_length) > 0U) {
1701+
++c_bit_length;
1702+
}
1703+
1704+
/* Initialise d and a. */
1705+
d = c >> (c_bit_length - 5);
1706+
b = _PyLong_Rshift(n, 2U*c - 62U);
1707+
if (b == NULL) {
1708+
goto error;
1709+
}
1710+
m = (uint64_t)PyLong_AsUnsignedLongLong(b);
1711+
Py_DECREF(b);
1712+
if (m == (uint64_t)(-1) && PyErr_Occurred()) {
1713+
goto error;
1714+
}
1715+
u = _approximate_isqrt(m) >> (31U - d);
1716+
a = PyLong_FromUnsignedLongLong((unsigned long long)u);
16691717
if (a == NULL) {
16701718
goto error;
16711719
}
1672-
d = 0;
1673-
while (--s >= 0) {
1720+
1721+
for (int s = c_bit_length - 6; s >= 0; --s) {
16741722
PyObject *q;
16751723
size_t e = d;
16761724

0 commit comments

Comments
 (0)