Skip to content

Commit d02c5e9

Browse files
authored
bpo-46258: Streamline isqrt fast path (#30333)
1 parent cfbde65 commit d02c5e9

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Speed up :func:`math.isqrt` for small positive integers by replacing two
2+
division steps with a lookup table.

Modules/mathmodule.c

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,20 +1718,49 @@ completes the proof sketch.
17181718
17191719
*/
17201720

1721+
/*
1722+
The _approximate_isqrt_tab table provides approximate square roots for
1723+
16-bit integers. For any n in the range 2**14 <= n < 2**16, the value
1724+
1725+
a = _approximate_isqrt_tab[(n >> 8) - 64]
1726+
1727+
is an approximate square root of n, satisfying (a - 1)**2 < n < (a + 1)**2.
1728+
1729+
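The table was computed in Python using the expression below, in which n denotes the high byte (the 16-bit input shifted right by 8) rather than the input itself, so 256*n + 128 is the midpoint of the 256 inputs that share a table entry: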
1730+
1731+
[min(round(sqrt(256*n + 128)), 255) for n in range(64, 256)]
1732+
*/
1733+
1734+
static const uint8_t _approximate_isqrt_tab[192] = {
1735+
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
1736+
140, 141, 142, 143, 144, 144, 145, 146, 147, 148, 149, 150,
1737+
151, 151, 152, 153, 154, 155, 156, 156, 157, 158, 159, 160,
1738+
160, 161, 162, 163, 164, 164, 165, 166, 167, 167, 168, 169,
1739+
170, 170, 171, 172, 173, 173, 174, 175, 176, 176, 177, 178,
1740+
179, 179, 180, 181, 181, 182, 183, 183, 184, 185, 186, 186,
1741+
187, 188, 188, 189, 190, 190, 191, 192, 192, 193, 194, 194,
1742+
195, 196, 196, 197, 198, 198, 199, 200, 200, 201, 201, 202,
1743+
203, 203, 204, 205, 205, 206, 206, 207, 208, 208, 209, 210,
1744+
210, 211, 211, 212, 213, 213, 214, 214, 215, 216, 216, 217,
1745+
217, 218, 219, 219, 220, 220, 221, 221, 222, 223, 223, 224,
1746+
224, 225, 225, 226, 227, 227, 228, 228, 229, 229, 230, 230,
1747+
231, 232, 232, 233, 233, 234, 234, 235, 235, 236, 237, 237,
1748+
238, 238, 239, 239, 240, 240, 241, 241, 242, 242, 243, 243,
1749+
244, 244, 245, 246, 246, 247, 247, 248, 248, 249, 249, 250,
1750+
250, 251, 251, 252, 252, 253, 253, 254, 254, 255, 255, 255,
1751+
};
17211752

17221753
/* Approximate square root of a large 64-bit integer.
17231754
17241755
Given `n` satisfying `2**62 <= n < 2**64`, return `a`
17251756
satisfying `(a - 1)**2 < n < (a + 1)**2`. */
17261757

1727-
static uint64_t
1758+
static inline uint32_t
17281759
_approximate_isqrt(uint64_t n)
17291760
{
1730-
uint32_t u = 1U + (n >> 62);
1731-
u = (u << 1) + (n >> 59) / u;
1732-
u = (u << 3) + (n >> 53) / u;
1733-
u = (u << 7) + (n >> 41) / u;
1734-
return (u << 15) + (n >> 17) / u;
1761+
uint32_t u = _approximate_isqrt_tab[(n >> 56) - 64];
1762+
u = (u << 7) + (uint32_t)(n >> 41) / u;
1763+
return (u << 15) + (uint32_t)((n >> 17) / u);
17351764
}
17361765

17371766
/*[clinic input]
@@ -1749,7 +1778,8 @@ math_isqrt(PyObject *module, PyObject *n)
17491778
{
17501779
int a_too_large, c_bit_length;
17511780
size_t c, d;
1752-
uint64_t m, u;
1781+
uint64_t m;
1782+
uint32_t u;
17531783
PyObject *a = NULL, *b;
17541784

17551785
n = _PyNumber_Index(n);
@@ -1776,18 +1806,17 @@ math_isqrt(PyObject *module, PyObject *n)
17761806
c = (c - 1U) / 2U;
17771807

17781808
/* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
1779-
fast, almost branch-free algorithm. In the final correction, we use `u*u
1780-
- 1 >= m` instead of the simpler `u*u > m` in order to get the correct
1781-
result in the corner case where `u=2**32`. */
1809+
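fast, almost branch-free algorithm. In the final correction, u is widened to 64 bits before squaring because u*u need not fit in 32 bits. */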
17821810
if (c <= 31U) {
1811+
int shift = 31 - (int)c;
17831812
m = (uint64_t)PyLong_AsUnsignedLongLong(n);
17841813
Py_DECREF(n);
17851814
if (m == (uint64_t)(-1) && PyErr_Occurred()) {
17861815
return NULL;
17871816
}
1788-
u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
1789-
u -= u * u - 1U >= m;
1790-
return PyLong_FromUnsignedLongLong((unsigned long long)u);
1817+
u = _approximate_isqrt(m << 2*shift) >> shift;
1818+
u -= (uint64_t)u * u > m;
1819+
return PyLong_FromUnsignedLong(u);
17911820
}
17921821

17931822
/* Slow path: n >= 2**64. We perform the first five iterations in C integer
@@ -1811,7 +1840,7 @@ math_isqrt(PyObject *module, PyObject *n)
18111840
goto error;
18121841
}
18131842
u = _approximate_isqrt(m) >> (31U - d);
1814-
a = PyLong_FromUnsignedLongLong((unsigned long long)u);
1843+
a = PyLong_FromUnsignedLong(u);
18151844
if (a == NULL) {
18161845
goto error;
18171846
}

0 commit comments

Comments
 (0)