Skip to content

Commit 2a53a87

Browse files
Refactor math.comb() implementation.
1 parent 6650105 commit 2a53a87

File tree

3 files changed

+91
-82
lines changed

3 files changed

+91
-82
lines changed

Lib/test/test_math.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,22 @@ def testComb(self):
19111911
self.assertRaises(ValueError, comb, 1, 10**10)
19121912
self.assertRaises(ValueError, comb, 0, 1)
19131913

1914-
1914+
n = 2**1000
1915+
self.assertEqual(comb(n, 0), 1)
1916+
self.assertEqual(comb(n, 1), n)
1917+
self.assertEqual(comb(n, 2), n * (n-1) // 2)
1918+
self.assertEqual(comb(n, n), 1)
1919+
self.assertEqual(comb(n, n-1), n)
1920+
self.assertEqual(comb(n, n-2), n * (n-1) // 2)
1921+
self.assertRaises(MemoryError, comb, n, n//2)
1922+
1923+
self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
1924+
self.assertRaises(TypeError, comb, 5.0, 2)
1925+
self.assertRaises(TypeError, comb, 5, 2.0)
1926+
self.assertRaises(TypeError, comb, decimal.Decimal(5.0), 2)
1927+
self.assertRaises(TypeError, comb, 5, decimal.Decimal(2.0))
1928+
self.assertRaises(TypeError, comb, '5', 2)
1929+
self.assertRaises(TypeError, comb, 5, '2')
19151930

19161931

19171932
def test_main():

Modules/clinic/mathmodule.c.h

Lines changed: 1 addition & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/mathmodule.c

Lines changed: 74 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,8 +3001,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
30013001
/*[clinic input]
30023002
math.comb
30033003
3004-
n: object(subclass_of='&PyLong_Type')
3005-
k: object(subclass_of='&PyLong_Type')
3004+
n: object
3005+
k: object
30063006
30073007
Number of ways to choose *k* items from *n* items without repetition and without order.
30083008
@@ -3017,103 +3017,105 @@ Raises ValueError if the arguments are negative or if k > n.
30173017

30183018
static PyObject *
30193019
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
3020-
/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
3020+
/*[clinic end generated code: output=bd2cec8d854f3493 input=b2160da8fe59df60]*/
30213021
{
3022-
PyObject *val = NULL,
3023-
*temp_obj1 = NULL,
3024-
*temp_obj2 = NULL,
3025-
*dump_var = NULL;
3022+
PyObject *result = NULL, *factor = NULL, *temp;
30263023
int overflow, cmp;
3027-
long long i, terms;
3024+
long i, factors;
3025+
3026+
n = PyNumber_Index(n);
3027+
if (n == NULL)
3028+
return NULL;
3029+
k = PyNumber_Index(k);
3030+
if (k == NULL) {
3031+
Py_DECREF(n);
3032+
return NULL;
3033+
}
30283034

30293035
cmp = PyObject_RichCompareBool(n, k, Py_LT);
30303036
if (cmp < 0) {
3031-
goto fail_comb;
3037+
goto error;
30323038
}
30333039
else if (cmp > 0) {
3034-
PyErr_Format(PyExc_ValueError,
3035-
"n must be an integer greater than or equal to k");
3036-
goto fail_comb;
3040+
PyErr_SetString(PyExc_ValueError,
3041+
"n must be an integer greater than or equal to k");
3042+
goto error;
30373043
}
30383044

3039-
/* b = min(b, a - b) */
3040-
dump_var = PyNumber_Subtract(n, k);
3041-
if (dump_var == NULL) {
3042-
goto fail_comb;
3043-
}
3044-
cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
3045-
if (cmp < 0) {
3046-
goto fail_comb;
3045+
/* k = min(k, n - k) */
3046+
temp = PyNumber_Subtract(n, k);
3047+
if (temp == NULL) {
3048+
goto error;
30473049
}
3048-
else if (cmp > 0) {
3049-
k = dump_var;
3050-
dump_var = NULL;
3050+
cmp = PyObject_RichCompareBool(k, temp, Py_GT);
3051+
if (cmp > 0) {
3052+
Py_SETREF(k, temp);
30513053
}
30523054
else {
3053-
Py_DECREF(dump_var);
3054-
dump_var = NULL;
3055+
Py_DECREF(temp);
3056+
if (cmp < 0 && PyErr_Occurred()) {
3057+
goto error;
3058+
}
30553059
}
30563060

3057-
terms = PyLong_AsLongLongAndOverflow(k, &overflow);
3058-
if (terms < 0 && PyErr_Occurred()) {
3059-
goto fail_comb;
3061+
factors = PyLong_AsLongAndOverflow(k, &overflow);
3062+
if (overflow > 0) {
3063+
PyErr_NoMemory();
3064+
goto error;
30603065
}
3061-
else if (overflow > 0) {
3062-
PyErr_Format(PyExc_OverflowError,
3063-
"minimum(n - k, k) must not exceed %lld",
3064-
LLONG_MAX);
3065-
goto fail_comb;
3066+
else if (overflow < 0 || factors < 0) {
3067+
if (!PyErr_Occurred()) {
3068+
PyErr_SetString(PyExc_ValueError,
3069+
"k must be a positive integer");
3070+
}
3071+
goto error;
30663072
}
3067-
else if (overflow < 0 || terms < 0) {
3068-
PyErr_Format(PyExc_ValueError,
3069-
"k must be a positive integer");
3070-
goto fail_comb;
3073+
3074+
if (factors == 0) {
3075+
result = PyLong_FromLong(1);
3076+
goto done;
30713077
}
30723078

3073-
if (terms == 0) {
3074-
return PyNumber_Long(_PyLong_One);
3079+
result = n;
3080+
Py_INCREF(result);
3081+
if (factors == 1) {
3082+
goto done;
30753083
}
30763084

3077-
val = PyNumber_Long(n);
3078-
for (i = 1; i < terms; ++i) {
3079-
temp_obj1 = PyLong_FromSsize_t(i);
3080-
if (temp_obj1 == NULL) {
3081-
goto fail_comb;
3082-
}
3083-
temp_obj2 = PyNumber_Subtract(n, temp_obj1);
3084-
if (temp_obj2 == NULL) {
3085-
goto fail_comb;
3085+
factor = n;
3086+
Py_INCREF(factor);
3087+
for (i = 1; i < factors; ++i) {
3088+
Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
3089+
if (factor == NULL) {
3090+
goto error;
30863091
}
3087-
dump_var = val;
3088-
val = PyNumber_Multiply(val, temp_obj2);
3089-
if (val == NULL) {
3090-
goto fail_comb;
3092+
Py_SETREF(result, PyNumber_Multiply(result, factor));
3093+
if (result == NULL) {
3094+
goto error;
30913095
}
3092-
Py_DECREF(dump_var);
3093-
dump_var = NULL;
3094-
Py_DECREF(temp_obj2);
3095-
temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
3096-
if (temp_obj2 == NULL) {
3097-
goto fail_comb;
3096+
3097+
temp = PyLong_FromUnsignedLong((unsigned long)i + 1);
3098+
if (temp == NULL) {
3099+
goto error;
30983100
}
3099-
dump_var = val;
3100-
val = PyNumber_FloorDivide(val, temp_obj2);
3101-
if (val == NULL) {
3102-
goto fail_comb;
3101+
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
3102+
Py_DECREF(temp);
3103+
if (result == NULL) {
3104+
goto error;
31033105
}
3104-
Py_DECREF(dump_var);
3105-
Py_DECREF(temp_obj1);
3106-
Py_DECREF(temp_obj2);
31073106
}
3107+
Py_DECREF(factor);
31083108

3109-
return val;
3110-
3111-
fail_comb:
3112-
Py_XDECREF(val);
3113-
Py_XDECREF(dump_var);
3114-
Py_XDECREF(temp_obj1);
3115-
Py_XDECREF(temp_obj2);
3109+
done:
3110+
Py_DECREF(n);
3111+
Py_DECREF(k);
3112+
return result;
31163113

3114+
error:
3115+
Py_XDECREF(factor);
3116+
Py_XDECREF(result);
3117+
Py_DECREF(n);
3118+
Py_DECREF(k);
31173119
return NULL;
31183120
}
31193121

0 commit comments

Comments
 (0)