Skip to content

Commit 2b843ac

Browse files
bpo-35431: Refactor math.comb() implementation. (GH-13725)
* Fixed some bugs. * Added support for index-likes objects. * Improved error messages. * Cleaned up and optimized the code. * Added more tests.
1 parent 9843bc1 commit 2b843ac

File tree

4 files changed

+111
-101
lines changed

4 files changed

+111
-101
lines changed

Doc/library/math.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,11 @@ Number-theoretic and representation functions
238238
and without order.
239239

240240
Also called the binomial coefficient. It is mathematically equal to the expression
241-
``n! / (k! (n - k)!)``. It is equivalent to the coefficient of k-th term in
241+
``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the
242242
polynomial expansion of the expression ``(1 + x) ** n``.
243243

244244
Raises :exc:`TypeError` if the arguments not integers.
245-
Raises :exc:`ValueError` if the arguments are negative or if k > n.
245+
Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
246246

247247
.. versionadded:: 3.8
248248

Lib/test/test_math.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,25 +1893,40 @@ def testComb(self):
18931893
# Raises TypeError if any argument is non-integer or argument count is
18941894
# not 2
18951895
self.assertRaises(TypeError, comb, 10, 1.0)
1896+
self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0))
18961897
self.assertRaises(TypeError, comb, 10, "1")
1897-
self.assertRaises(TypeError, comb, "10", 1)
18981898
self.assertRaises(TypeError, comb, 10.0, 1)
1899+
self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1)
1900+
self.assertRaises(TypeError, comb, "10", 1)
18991901

19001902
self.assertRaises(TypeError, comb, 10)
19011903
self.assertRaises(TypeError, comb, 10, 1, 3)
19021904
self.assertRaises(TypeError, comb)
19031905

19041906
# Raises Value error if not k or n are negative numbers
19051907
self.assertRaises(ValueError, comb, -1, 1)
1906-
self.assertRaises(ValueError, comb, -10*10, 1)
1908+
self.assertRaises(ValueError, comb, -2**1000, 1)
19071909
self.assertRaises(ValueError, comb, 1, -1)
1908-
self.assertRaises(ValueError, comb, 1, -10*10)
1910+
self.assertRaises(ValueError, comb, 1, -2**1000)
19091911

19101912
# Raises value error if k is greater than n
1911-
self.assertRaises(ValueError, comb, 1, 10**10)
1912-
self.assertRaises(ValueError, comb, 0, 1)
1913-
1914-
1913+
self.assertRaises(ValueError, comb, 1, 2)
1914+
self.assertRaises(ValueError, comb, 1, 2**1000)
1915+
1916+
n = 2**1000
1917+
self.assertEqual(comb(n, 0), 1)
1918+
self.assertEqual(comb(n, 1), n)
1919+
self.assertEqual(comb(n, 2), n * (n-1) // 2)
1920+
self.assertEqual(comb(n, n), 1)
1921+
self.assertEqual(comb(n, n-1), n)
1922+
self.assertEqual(comb(n, n-2), n * (n-1) // 2)
1923+
self.assertRaises((OverflowError, MemoryError), comb, n, n//2)
1924+
1925+
for n, k in (True, True), (True, False), (False, False):
1926+
self.assertEqual(comb(n, k), 1)
1927+
self.assertIs(type(comb(n, k)), int)
1928+
self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
1929+
self.assertIs(type(comb(MyIndexable(5), MyIndexable(2))), int)
19151930

19161931

19171932
def test_main():

Modules/clinic/mathmodule.c.h

Lines changed: 6 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/mathmodule.c

Lines changed: 81 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
30013001
/*[clinic input]
30023002
math.comb
30033003
3004-
n: object(subclass_of='&PyLong_Type')
3005-
k: object(subclass_of='&PyLong_Type')
3004+
n: object
3005+
k: object
3006+
/
30063007
3007-
Number of ways to choose *k* items from *n* items without repetition and without order.
3008+
Number of ways to choose k items from n items without repetition and without order.
30083009
30093010
Also called the binomial coefficient. It is mathematically equal to the expression
30103011
n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
@@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n.
30173018

30183019
static PyObject *
30193020
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
3020-
/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
3021+
/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
30213022
{
3022-
PyObject *val = NULL,
3023-
*temp_obj1 = NULL,
3024-
*temp_obj2 = NULL,
3025-
*dump_var = NULL;
3023+
PyObject *result = NULL, *factor = NULL, *temp;
30263024
int overflow, cmp;
3027-
long long i, terms;
3025+
long long i, factors;
30283026

3029-
cmp = PyObject_RichCompareBool(n, k, Py_LT);
3030-
if (cmp < 0) {
3031-
goto fail_comb;
3027+
n = PyNumber_Index(n);
3028+
if (n == NULL) {
3029+
return NULL;
30323030
}
3033-
else if (cmp > 0) {
3034-
PyErr_Format(PyExc_ValueError,
3035-
"n must be an integer greater than or equal to k");
3036-
goto fail_comb;
3031+
k = PyNumber_Index(k);
3032+
if (k == NULL) {
3033+
Py_DECREF(n);
3034+
return NULL;
30373035
}
30383036

3039-
/* b = min(b, a - b) */
3040-
dump_var = PyNumber_Subtract(n, k);
3041-
if (dump_var == NULL) {
3042-
goto fail_comb;
3037+
if (Py_SIZE(n) < 0) {
3038+
PyErr_SetString(PyExc_ValueError,
3039+
"n must be a non-negative integer");
3040+
goto error;
30433041
}
3044-
cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
3045-
if (cmp < 0) {
3046-
goto fail_comb;
3042+
/* k = min(k, n - k) */
3043+
temp = PyNumber_Subtract(n, k);
3044+
if (temp == NULL) {
3045+
goto error;
30473046
}
3048-
else if (cmp > 0) {
3049-
k = dump_var;
3050-
dump_var = NULL;
3047+
if (Py_SIZE(temp) < 0) {
3048+
Py_DECREF(temp);
3049+
PyErr_SetString(PyExc_ValueError,
3050+
"k must be an integer less than or equal to n");
3051+
goto error;
3052+
}
3053+
cmp = PyObject_RichCompareBool(k, temp, Py_GT);
3054+
if (cmp > 0) {
3055+
Py_SETREF(k, temp);
30513056
}
30523057
else {
3053-
Py_DECREF(dump_var);
3054-
dump_var = NULL;
3058+
Py_DECREF(temp);
3059+
if (cmp < 0) {
3060+
goto error;
3061+
}
30553062
}
30563063

3057-
terms = PyLong_AsLongLongAndOverflow(k, &overflow);
3058-
if (terms < 0 && PyErr_Occurred()) {
3059-
goto fail_comb;
3060-
}
3061-
else if (overflow > 0) {
3064+
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
3065+
if (overflow > 0) {
30623066
PyErr_Format(PyExc_OverflowError,
3063-
"minimum(n - k, k) must not exceed %lld",
3067+
"min(n - k, k) must not exceed %lld",
30643068
LLONG_MAX);
3065-
goto fail_comb;
3069+
goto error;
30663070
}
3067-
else if (overflow < 0 || terms < 0) {
3068-
PyErr_Format(PyExc_ValueError,
3069-
"k must be a positive integer");
3070-
goto fail_comb;
3071+
else if (overflow < 0 || factors < 0) {
3072+
if (!PyErr_Occurred()) {
3073+
PyErr_SetString(PyExc_ValueError,
3074+
"k must be a non-negative integer");
3075+
}
3076+
goto error;
30713077
}
30723078

3073-
if (terms == 0) {
3074-
return PyNumber_Long(_PyLong_One);
3079+
if (factors == 0) {
3080+
result = PyLong_FromLong(1);
3081+
goto done;
30753082
}
30763083

3077-
val = PyNumber_Long(n);
3078-
for (i = 1; i < terms; ++i) {
3079-
temp_obj1 = PyLong_FromSsize_t(i);
3080-
if (temp_obj1 == NULL) {
3081-
goto fail_comb;
3082-
}
3083-
temp_obj2 = PyNumber_Subtract(n, temp_obj1);
3084-
if (temp_obj2 == NULL) {
3085-
goto fail_comb;
3084+
result = n;
3085+
Py_INCREF(result);
3086+
if (factors == 1) {
3087+
goto done;
3088+
}
3089+
3090+
factor = n;
3091+
Py_INCREF(factor);
3092+
for (i = 1; i < factors; ++i) {
3093+
Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
3094+
if (factor == NULL) {
3095+
goto error;
30863096
}
3087-
dump_var = val;
3088-
val = PyNumber_Multiply(val, temp_obj2);
3089-
if (val == NULL) {
3090-
goto fail_comb;
3097+
Py_SETREF(result, PyNumber_Multiply(result, factor));
3098+
if (result == NULL) {
3099+
goto error;
30913100
}
3092-
Py_DECREF(dump_var);
3093-
dump_var = NULL;
3094-
Py_DECREF(temp_obj2);
3095-
temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
3096-
if (temp_obj2 == NULL) {
3097-
goto fail_comb;
3101+
3102+
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
3103+
if (temp == NULL) {
3104+
goto error;
30983105
}
3099-
dump_var = val;
3100-
val = PyNumber_FloorDivide(val, temp_obj2);
3101-
if (val == NULL) {
3102-
goto fail_comb;
3106+
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
3107+
Py_DECREF(temp);
3108+
if (result == NULL) {
3109+
goto error;
31033110
}
3104-
Py_DECREF(dump_var);
3105-
Py_DECREF(temp_obj1);
3106-
Py_DECREF(temp_obj2);
31073111
}
3112+
Py_DECREF(factor);
31083113

3109-
return val;
3110-
3111-
fail_comb:
3112-
Py_XDECREF(val);
3113-
Py_XDECREF(dump_var);
3114-
Py_XDECREF(temp_obj1);
3115-
Py_XDECREF(temp_obj2);
3114+
done:
3115+
Py_DECREF(n);
3116+
Py_DECREF(k);
3117+
return result;
31163118

3119+
error:
3120+
Py_XDECREF(factor);
3121+
Py_XDECREF(result);
3122+
Py_DECREF(n);
3123+
Py_DECREF(k);
31173124
return NULL;
31183125
}
31193126

0 commit comments

Comments
 (0)