-
-
Notifications
You must be signed in to change notification settings - Fork 32.2k
bpo-38005: Fixed comparing and creating of InterpreterID and ChannelID. #15652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
326bd04
6c28b76
5524e62
b6bf9c3
f3603dd
b768077
052fc77
c8219f9
6d56a37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -526,30 +526,23 @@ def test_with_int(self): | |
self.assertEqual(int(id), 10) | ||
|
||
def test_coerce_id(self): | ||
id = interpreters.InterpreterID('10', force=True) | ||
self.assertEqual(int(id), 10) | ||
|
||
id = interpreters.InterpreterID(10.0, force=True) | ||
self.assertEqual(int(id), 10) | ||
|
||
class Int(str): | ||
def __init__(self, value): | ||
self._value = value | ||
def __int__(self): | ||
return self._value | ||
def __index__(self): | ||
return 10 | ||
|
||
id = interpreters.InterpreterID(Int(10), force=True) | ||
self.assertEqual(int(id), 10) | ||
for id in ('10', '1_0', Int()): | ||
with self.subTest(id=id): | ||
id = interpreters.InterpreterID(id, force=True) | ||
self.assertEqual(int(id), 10) | ||
|
||
def test_bad_id(self): | ||
for id in [-1, 'spam']: | ||
with self.subTest(id): | ||
with self.assertRaises(ValueError): | ||
interpreters.InterpreterID(id) | ||
with self.assertRaises(OverflowError): | ||
interpreters.InterpreterID(2**64) | ||
with self.assertRaises(TypeError): | ||
interpreters.InterpreterID(object()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to have a test still for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I replaced it with |
||
self.assertRaises(TypeError, interpreters.InterpreterID, object()) | ||
self.assertRaises(TypeError, interpreters.InterpreterID, 10.0) | ||
self.assertRaises(TypeError, interpreters.InterpreterID, b'10') | ||
self.assertRaises(ValueError, interpreters.InterpreterID, -1) | ||
self.assertRaises(ValueError, interpreters.InterpreterID, '-1') | ||
self.assertRaises(ValueError, interpreters.InterpreterID, 'spam') | ||
self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64) | ||
|
||
def test_does_not_exist(self): | ||
id = interpreters.channel_create() | ||
|
@@ -572,6 +565,14 @@ def test_equality(self): | |
self.assertTrue(id1 == id1) | ||
self.assertTrue(id1 == id2) | ||
self.assertTrue(id1 == int(id1)) | ||
self.assertTrue(int(id1) == id1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice :) |
||
self.assertTrue(id1 == float(int(id1))) | ||
self.assertTrue(float(int(id1)) == id1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one test crashed with the past implementation. |
||
self.assertFalse(id1 == float(int(id1)) + 0.1) | ||
self.assertFalse(id1 == str(int(id1))) | ||
self.assertFalse(id1 == 2**1000) | ||
self.assertFalse(id1 == float('inf')) | ||
self.assertFalse(id1 == 'spam') | ||
self.assertFalse(id1 == id3) | ||
|
||
self.assertFalse(id1 != id1) | ||
|
@@ -1105,30 +1106,20 @@ def test_with_kwargs(self): | |
self.assertEqual(cid.end, 'both') | ||
|
||
def test_coerce_id(self): | ||
cid = interpreters._channel_id('10', force=True) | ||
self.assertEqual(int(cid), 10) | ||
|
||
cid = interpreters._channel_id(10.0, force=True) | ||
self.assertEqual(int(cid), 10) | ||
|
||
class Int(str): | ||
def __init__(self, value): | ||
self._value = value | ||
def __int__(self): | ||
return self._value | ||
def __index__(self): | ||
return 10 | ||
|
||
cid = interpreters._channel_id(Int(10), force=True) | ||
cid = interpreters._channel_id(Int(), force=True) | ||
self.assertEqual(int(cid), 10) | ||
|
||
def test_bad_id(self): | ||
for cid in [-1, 'spam']: | ||
with self.subTest(cid): | ||
with self.assertRaises(ValueError): | ||
interpreters._channel_id(cid) | ||
with self.assertRaises(OverflowError): | ||
interpreters._channel_id(2**64) | ||
with self.assertRaises(TypeError): | ||
interpreters._channel_id(object()) | ||
self.assertRaises(TypeError, interpreters._channel_id, object()) | ||
self.assertRaises(TypeError, interpreters._channel_id, 10.0) | ||
self.assertRaises(TypeError, interpreters._channel_id, '10') | ||
self.assertRaises(TypeError, interpreters._channel_id, b'10') | ||
self.assertRaises(ValueError, interpreters._channel_id, -1) | ||
self.assertRaises(OverflowError, interpreters._channel_id, 2**64) | ||
|
||
def test_bad_kwargs(self): | ||
with self.assertRaises(ValueError): | ||
|
@@ -1164,6 +1155,14 @@ def test_equality(self): | |
self.assertTrue(cid1 == cid1) | ||
self.assertTrue(cid1 == cid2) | ||
self.assertTrue(cid1 == int(cid1)) | ||
self.assertTrue(int(cid1) == cid1) | ||
self.assertTrue(cid1 == float(int(cid1))) | ||
self.assertTrue(float(int(cid1)) == cid1) | ||
self.assertFalse(cid1 == float(int(cid1)) + 0.1) | ||
self.assertFalse(cid1 == str(int(cid1))) | ||
self.assertFalse(cid1 == 2**1000) | ||
self.assertFalse(cid1 == float('inf')) | ||
self.assertFalse(cid1 == 'spam') | ||
self.assertFalse(cid1 == cid3) | ||
|
||
self.assertFalse(cid1 != cid1) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Fixed comparing and creating of InterpreterID and ChannelID. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1405,6 +1405,34 @@ typedef struct channelid { | |
_channels *channels; | ||
} channelid; | ||
|
||
static int | ||
channel_id_converter(PyObject *arg, void *ptr) | ||
{ | ||
int64_t cid; | ||
if (PyObject_TypeCheck(arg, &ChannelIDtype)) { | ||
cid = ((channelid *)arg)->id; | ||
} | ||
else if (PyIndex_Check(arg)) { | ||
cid = PyLong_AsLongLong(arg); | ||
if (cid == -1 && PyErr_Occurred()) { | ||
return 0; | ||
} | ||
if (cid < 0) { | ||
PyErr_Format(PyExc_ValueError, | ||
"channel ID must be a non-negative int, got %R", arg); | ||
return 0; | ||
} | ||
} | ||
else { | ||
PyErr_Format(PyExc_TypeError, | ||
"channel ID must be an int, got %.100s", | ||
arg->ob_type->tp_name); | ||
return 0; | ||
} | ||
*(int64_t *)ptr = cid; | ||
return 1; | ||
} | ||
|
||
static channelid * | ||
newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels, | ||
int force, int resolve) | ||
|
@@ -1437,28 +1465,16 @@ static PyObject * | |
channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; | ||
PyObject *id; | ||
int64_t cid; | ||
int send = -1; | ||
int recv = -1; | ||
int force = 0; | ||
int resolve = 0; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"O|$pppp:ChannelID.__new__", kwlist, | ||
&id, &send, &recv, &force, &resolve)) | ||
"O&|$pppp:ChannelID.__new__", kwlist, | ||
channel_id_converter, &cid, &send, &recv, &force, &resolve)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a neat trick. :) |
||
return NULL; | ||
|
||
// Coerce and check the ID. | ||
int64_t cid; | ||
if (PyObject_TypeCheck(id, &ChannelIDtype)) { | ||
cid = ((channelid *)id)->id; | ||
} | ||
else { | ||
cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
return NULL; | ||
} | ||
} | ||
|
||
// Handle "send" and "recv". | ||
if (send == 0 && recv == 0) { | ||
PyErr_SetString(PyExc_ValueError, | ||
|
@@ -1592,30 +1608,28 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) | |
int equal; | ||
if (PyObject_TypeCheck(other, &ChannelIDtype)) { | ||
channelid *othercid = (channelid *)other; | ||
if (cid->end != othercid->end) { | ||
equal = 0; | ||
} | ||
else { | ||
equal = (cid->id == othercid->id); | ||
} | ||
equal = (cid->end == othercid->end) && (cid->id == othercid->id); | ||
} | ||
else { | ||
other = PyNumber_Long(other); | ||
if (other == NULL) { | ||
PyErr_Clear(); | ||
Py_RETURN_NOTIMPLEMENTED; | ||
} | ||
int64_t othercid = PyLong_AsLongLong(other); | ||
Py_DECREF(other); | ||
if (othercid == -1 && PyErr_Occurred() != NULL) { | ||
else if (PyLong_Check(other)) { | ||
/* Fast path */ | ||
int overflow; | ||
long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use int64_t or the macro? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using But using |
||
if (othercid == -1 && PyErr_Occurred()) { | ||
return NULL; | ||
} | ||
if (othercid < 0) { | ||
equal = 0; | ||
} | ||
else { | ||
equal = (cid->id == othercid); | ||
equal = !overflow && (othercid >= 0) && (cid->id == othercid); | ||
} | ||
else if (PyNumber_Check(other)) { | ||
PyObject *pyid = PyLong_FromLongLong(cid->id); | ||
if (pyid == NULL) { | ||
return NULL; | ||
} | ||
PyObject *res = PyObject_RichCompare(pyid, other, op); | ||
Py_DECREF(pyid); | ||
return res; | ||
} | ||
else { | ||
Py_RETURN_NOTIMPLEMENTED; | ||
} | ||
|
||
if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) { | ||
|
@@ -1754,8 +1768,7 @@ static PyTypeObject ChannelIDtype = { | |
0, /* tp_getattro */ | ||
0, /* tp_setattro */ | ||
0, /* tp_as_buffer */ | ||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | | ||
Py_TPFLAGS_LONG_SUBCLASS, /* tp_flags */ | ||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ | ||
channelid_doc, /* tp_doc */ | ||
0, /* tp_traverse */ | ||
0, /* tp_clear */ | ||
|
@@ -2017,10 +2030,6 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds) | |
"O:destroy", kwlist, &id)) { | ||
return NULL; | ||
} | ||
if (!PyLong_Check(id)) { | ||
PyErr_SetString(PyExc_TypeError, "ID must be an int"); | ||
return NULL; | ||
} | ||
|
||
// Look up the interpreter. | ||
PyInterpreterState *interp = _PyInterpreterID_LookUp(id); | ||
|
@@ -2145,10 +2154,6 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) | |
&id, &code, &shared)) { | ||
return NULL; | ||
} | ||
if (!PyLong_Check(id)) { | ||
PyErr_SetString(PyExc_TypeError, "first arg (ID) must be an int"); | ||
return NULL; | ||
} | ||
|
||
// Look up the interpreter. | ||
PyInterpreterState *interp = _PyInterpreterID_LookUp(id); | ||
|
@@ -2216,10 +2221,6 @@ interp_is_running(PyObject *self, PyObject *args, PyObject *kwds) | |
"O:is_running", kwlist, &id)) { | ||
return NULL; | ||
} | ||
if (!PyLong_Check(id)) { | ||
PyErr_SetString(PyExc_TypeError, "ID must be an int"); | ||
return NULL; | ||
} | ||
|
||
PyInterpreterState *interp = _PyInterpreterID_LookUp(id); | ||
if (interp == NULL) { | ||
|
@@ -2268,13 +2269,9 @@ static PyObject * | |
channel_destroy(PyObject *self, PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"cid", NULL}; | ||
PyObject *id; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"O:channel_destroy", kwlist, &id)) { | ||
return NULL; | ||
} | ||
int64_t cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
int64_t cid; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist, | ||
channel_id_converter, &cid)) { | ||
return NULL; | ||
} | ||
|
||
|
@@ -2331,14 +2328,10 @@ static PyObject * | |
channel_send(PyObject *self, PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"cid", "obj", NULL}; | ||
PyObject *id; | ||
int64_t cid; | ||
PyObject *obj; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"OO:channel_send", kwlist, &id, &obj)) { | ||
return NULL; | ||
} | ||
int64_t cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist, | ||
channel_id_converter, &cid, &obj)) { | ||
return NULL; | ||
} | ||
|
||
|
@@ -2357,13 +2350,9 @@ static PyObject * | |
channel_recv(PyObject *self, PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"cid", NULL}; | ||
PyObject *id; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"O:channel_recv", kwlist, &id)) { | ||
return NULL; | ||
} | ||
int64_t cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
int64_t cid; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_recv", kwlist, | ||
channel_id_converter, &cid)) { | ||
return NULL; | ||
} | ||
|
||
|
@@ -2379,17 +2368,13 @@ static PyObject * | |
channel_close(PyObject *self, PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; | ||
PyObject *id; | ||
int64_t cid; | ||
int send = 0; | ||
int recv = 0; | ||
int force = 0; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"O|$ppp:channel_close", kwlist, | ||
&id, &send, &recv, &force)) { | ||
return NULL; | ||
} | ||
int64_t cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
"O&|$ppp:channel_close", kwlist, | ||
channel_id_converter, &cid, &send, &recv, &force)) { | ||
return NULL; | ||
} | ||
|
||
|
@@ -2431,17 +2416,13 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds) | |
{ | ||
// Note that only the current interpreter is affected. | ||
static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; | ||
PyObject *id; | ||
int64_t cid; | ||
int send = 0; | ||
int recv = 0; | ||
int force = 0; | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, | ||
"O|$ppp:channel_release", kwlist, | ||
&id, &send, &recv, &force)) { | ||
return NULL; | ||
} | ||
int64_t cid = _Py_CoerceID(id); | ||
if (cid < 0) { | ||
"O&|$ppp:channel_release", kwlist, | ||
channel_id_converter, &cid, &send, &recv, &force)) { | ||
return NULL; | ||
} | ||
if (send == 0 && recv == 0) { | ||
|
@@ -2538,7 +2519,6 @@ PyInit__xxsubinterpreters(void) | |
} | ||
|
||
/* Initialize types */ | ||
ChannelIDtype.tp_base = &PyLong_Type; | ||
if (PyType_Ready(&ChannelIDtype) != 0) { | ||
return NULL; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__index__
makes sense, but what about__int__
too. Shouldn't both work?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
float
has__int__
, but not__index__
.(10.1).__int__()
returns10
. So__int__
should not work.Note that I changed
return self._value
toreturn 10
becauseint(10)
andint(str(10))
has the same result. So it was not clear what way thestr
subclass with__int__
used, bot has the same result.