diff --git a/Include/cpython/interpreteridobject.h b/Include/cpython/interpreteridobject.h index cb72c2b0895..67ec5873542 100644 --- a/Include/cpython/interpreteridobject.h +++ b/Include/cpython/interpreteridobject.h @@ -14,8 +14,6 @@ PyAPI_FUNC(PyObject *) _PyInterpreterID_New(int64_t); PyAPI_FUNC(PyObject *) _PyInterpreterState_GetIDObject(PyInterpreterState *); PyAPI_FUNC(PyInterpreterState *) _PyInterpreterID_LookUp(PyObject *); -PyAPI_FUNC(int64_t) _Py_CoerceID(PyObject *); - #ifdef __cplusplus } #endif diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 78b2030a1f6..f5474f4101d 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -526,30 +526,23 @@ class InterpreterIDTests(TestBase): self.assertEqual(int(id), 10) def test_coerce_id(self): - id = interpreters.InterpreterID('10', force=True) - self.assertEqual(int(id), 10) - - id = interpreters.InterpreterID(10.0, force=True) - self.assertEqual(int(id), 10) - class Int(str): - def __init__(self, value): - self._value = value - def __int__(self): - return self._value + def __index__(self): + return 10 - id = interpreters.InterpreterID(Int(10), force=True) - self.assertEqual(int(id), 10) + for id in ('10', '1_0', Int()): + with self.subTest(id=id): + id = interpreters.InterpreterID(id, force=True) + self.assertEqual(int(id), 10) def test_bad_id(self): - for id in [-1, 'spam']: - with self.subTest(id): - with self.assertRaises(ValueError): - interpreters.InterpreterID(id) - with self.assertRaises(OverflowError): - interpreters.InterpreterID(2**64) - with self.assertRaises(TypeError): - interpreters.InterpreterID(object()) + self.assertRaises(TypeError, interpreters.InterpreterID, object()) + self.assertRaises(TypeError, interpreters.InterpreterID, 10.0) + self.assertRaises(TypeError, interpreters.InterpreterID, b'10') + self.assertRaises(ValueError, interpreters.InterpreterID, -1) + self.assertRaises(ValueError, interpreters.InterpreterID, '-1') + self.assertRaises(ValueError, interpreters.InterpreterID, 'spam') + self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64) def test_does_not_exist(self): id = interpreters.channel_create() @@ -572,6 +565,14 @@ class InterpreterIDTests(TestBase): self.assertTrue(id1 == id1) self.assertTrue(id1 == id2) self.assertTrue(id1 == int(id1)) + self.assertTrue(int(id1) == id1) + self.assertTrue(id1 == float(int(id1))) + self.assertTrue(float(int(id1)) == id1) + self.assertFalse(id1 == float(int(id1)) + 0.1) + self.assertFalse(id1 == str(int(id1))) + self.assertFalse(id1 == 2**1000) + self.assertFalse(id1 == float('inf')) + self.assertFalse(id1 == 'spam') self.assertFalse(id1 == id3) self.assertFalse(id1 != id1) @@ -1105,30 +1106,20 @@ class ChannelIDTests(TestBase): self.assertEqual(cid.end, 'both') def test_coerce_id(self): - cid = interpreters._channel_id('10', force=True) - self.assertEqual(int(cid), 10) - - cid = interpreters._channel_id(10.0, force=True) - self.assertEqual(int(cid), 10) - class Int(str): - def __init__(self, value): - self._value = value - def __int__(self): - return self._value + def __index__(self): + return 10 - cid = interpreters._channel_id(Int(10), force=True) + cid = interpreters._channel_id(Int(), force=True) self.assertEqual(int(cid), 10) def test_bad_id(self): - for cid in [-1, 'spam']: - with self.subTest(cid): - with self.assertRaises(ValueError): - interpreters._channel_id(cid) - with self.assertRaises(OverflowError): - interpreters._channel_id(2**64) - with self.assertRaises(TypeError): - interpreters._channel_id(object()) + self.assertRaises(TypeError, interpreters._channel_id, object()) + self.assertRaises(TypeError, interpreters._channel_id, 10.0) + self.assertRaises(TypeError, interpreters._channel_id, '10') + self.assertRaises(TypeError, interpreters._channel_id, b'10') + self.assertRaises(ValueError, interpreters._channel_id, -1) + self.assertRaises(OverflowError, interpreters._channel_id, 2**64) def test_bad_kwargs(self): with self.assertRaises(ValueError): @@ -1164,6 +1155,14 @@ class ChannelIDTests(TestBase): self.assertTrue(cid1 == cid1) self.assertTrue(cid1 == cid2) self.assertTrue(cid1 == int(cid1)) + self.assertTrue(int(cid1) == cid1) + self.assertTrue(cid1 == float(int(cid1))) + self.assertTrue(float(int(cid1)) == cid1) + self.assertFalse(cid1 == float(int(cid1)) + 0.1) + self.assertFalse(cid1 == str(int(cid1))) + self.assertFalse(cid1 == 2**1000) + self.assertFalse(cid1 == float('inf')) + self.assertFalse(cid1 == 'spam') self.assertFalse(cid1 == cid3) self.assertFalse(cid1 != cid1) diff --git a/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst b/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst new file mode 100644 index 00000000000..706abf587b9 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst @@ -0,0 +1 @@ +Fixed comparing and creating of InterpreterID and ChannelID. diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 19d98fd9693..7842947e54a 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -1405,6 +1405,34 @@ typedef struct channelid { _channels *channels; } channelid; +static int +channel_id_converter(PyObject *arg, void *ptr) +{ + int64_t cid; + if (PyObject_TypeCheck(arg, &ChannelIDtype)) { + cid = ((channelid *)arg)->id; + } + else if (PyIndex_Check(arg)) { + cid = PyLong_AsLongLong(arg); + if (cid == -1 && PyErr_Occurred()) { + return 0; + } + if (cid < 0) { + PyErr_Format(PyExc_ValueError, + "channel ID must be a non-negative int, got %R", arg); + return 0; + } + } + else { + PyErr_Format(PyExc_TypeError, + "channel ID must be an int, got %.100s", + arg->ob_type->tp_name); + return 0; + } + *(int64_t *)ptr = cid; + return 1; +} + static channelid * newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels, int force, int resolve) @@ -1437,28 +1465,16 @@ static PyObject * channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; - PyObject *id; + int64_t cid; int send = -1; int recv = -1; int force = 0; int resolve = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$pppp:ChannelID.__new__", kwlist, - &id, &send, &recv, &force, &resolve)) + "O&|$pppp:ChannelID.__new__", kwlist, + channel_id_converter, &cid, &send, &recv, &force, &resolve)) return NULL; - // Coerce and check the ID. - int64_t cid; - if (PyObject_TypeCheck(id, &ChannelIDtype)) { - cid = ((channelid *)id)->id; - } - else { - cid = _Py_CoerceID(id); - if (cid < 0) { - return NULL; - } - } - // Handle "send" and "recv". if (send == 0 && recv == 0) { PyErr_SetString(PyExc_ValueError, @@ -1592,30 +1608,28 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) int equal; if (PyObject_TypeCheck(other, &ChannelIDtype)) { channelid *othercid = (channelid *)other; - if (cid->end != othercid->end) { - equal = 0; - } - else { - equal = (cid->id == othercid->id); - } + equal = (cid->end == othercid->end) && (cid->id == othercid->id); } - else { - other = PyNumber_Long(other); - if (other == NULL) { - PyErr_Clear(); - Py_RETURN_NOTIMPLEMENTED; - } - int64_t othercid = PyLong_AsLongLong(other); - Py_DECREF(other); - if (othercid == -1 && PyErr_Occurred() != NULL) { + else if (PyLong_Check(other)) { + /* Fast path */ + int overflow; + long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow); + if (othercid == -1 && PyErr_Occurred()) { return NULL; } - if (othercid < 0) { - equal = 0; - } - else { - equal = (cid->id == othercid); + equal = !overflow && (othercid >= 0) && (cid->id == othercid); + } + else if (PyNumber_Check(other)) { + PyObject *pyid = PyLong_FromLongLong(cid->id); + if (pyid == NULL) { + return NULL; } + PyObject *res = PyObject_RichCompare(pyid, other, op); + Py_DECREF(pyid); + return res; + } + else { + Py_RETURN_NOTIMPLEMENTED; } if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) { @@ -1754,8 +1768,7 @@ static PyTypeObject ChannelIDtype = { 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | - Py_TPFLAGS_LONG_SUBCLASS, /* tp_flags */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ channelid_doc, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ @@ -2017,10 +2030,6 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds) "O:destroy", kwlist, &id)) { return NULL; } - if (!PyLong_Check(id)) { - PyErr_SetString(PyExc_TypeError, "ID must be an int"); - return NULL; - } // Look up the interpreter. PyInterpreterState *interp = _PyInterpreterID_LookUp(id); @@ -2145,10 +2154,6 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) &id, &code, &shared)) { return NULL; } - if (!PyLong_Check(id)) { - PyErr_SetString(PyExc_TypeError, "first arg (ID) must be an int"); - return NULL; - } // Look up the interpreter. PyInterpreterState *interp = _PyInterpreterID_LookUp(id); @@ -2216,10 +2221,6 @@ interp_is_running(PyObject *self, PyObject *args, PyObject *kwds) "O:is_running", kwlist, &id)) { return NULL; } - if (!PyLong_Check(id)) { - PyErr_SetString(PyExc_TypeError, "ID must be an int"); - return NULL; - } PyInterpreterState *interp = _PyInterpreterID_LookUp(id); if (interp == NULL) { @@ -2268,13 +2269,9 @@ static PyObject * channel_destroy(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", NULL}; - PyObject *id; - if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O:channel_destroy", kwlist, &id)) { - return NULL; - } - int64_t cid = _Py_CoerceID(id); - if (cid < 0) { + int64_t cid; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist, + channel_id_converter, &cid)) { return NULL; } @@ -2331,14 +2328,10 @@ static PyObject * channel_send(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "obj", NULL}; - PyObject *id; + int64_t cid; PyObject *obj; - if (!PyArg_ParseTupleAndKeywords(args, kwds, - "OO:channel_send", kwlist, &id, &obj)) { - return NULL; - } - int64_t cid = _Py_CoerceID(id); - if (cid < 0) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist, + channel_id_converter, &cid, &obj)) { return NULL; } @@ -2357,13 +2350,9 @@ static PyObject * channel_recv(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", NULL}; - PyObject *id; - if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O:channel_recv", kwlist, &id)) { - return NULL; - } - int64_t cid = _Py_CoerceID(id); - if (cid < 0) { + int64_t cid; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_recv", kwlist, + channel_id_converter, &cid)) { return NULL; } @@ -2379,17 +2368,13 @@ static PyObject * channel_close(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; - PyObject *id; + int64_t cid; int send = 0; int recv = 0; int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$ppp:channel_close", kwlist, - &id, &send, &recv, &force)) { - return NULL; - } - int64_t cid = _Py_CoerceID(id); - if (cid < 0) { + "O&|$ppp:channel_close", kwlist, + channel_id_converter, &cid, &send, &recv, &force)) { return NULL; } @@ -2431,17 +2416,13 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds) { // Note that only the current interpreter is affected. static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; - PyObject *id; + int64_t cid; int send = 0; int recv = 0; int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$ppp:channel_release", kwlist, - &id, &send, &recv, &force)) { - return NULL; - } - int64_t cid = _Py_CoerceID(id); - if (cid < 0) { + "O&|$ppp:channel_release", kwlist, + channel_id_converter, &cid, &send, &recv, &force)) { return NULL; } if (send == 0 && recv == 0) { @@ -2538,7 +2519,6 @@ PyInit__xxsubinterpreters(void) } /* Initialize types */ - ChannelIDtype.tp_base = &PyLong_Type; if (PyType_Ready(&ChannelIDtype) != 0) { return NULL; } diff --git a/Objects/interpreteridobject.c b/Objects/interpreteridobject.c index 0a1dfa25795..3edbb85e6ac 100644 --- a/Objects/interpreteridobject.c +++ b/Objects/interpreteridobject.c @@ -5,38 +5,6 @@ #include "interpreteridobject.h" -int64_t -_Py_CoerceID(PyObject *orig) -{ - PyObject *pyid = PyNumber_Long(orig); - if (pyid == NULL) { - if (PyErr_ExceptionMatches(PyExc_TypeError)) { - PyErr_Format(PyExc_TypeError, - "'id' must be a non-negative int, got %R", orig); - } - else { - PyErr_Format(PyExc_ValueError, - "'id' must be a non-negative int, got %R", orig); - } - return -1; - } - int64_t id = PyLong_AsLongLong(pyid); - Py_DECREF(pyid); - if (id == -1 && PyErr_Occurred() != NULL) { - if (!PyErr_ExceptionMatches(PyExc_OverflowError)) { - PyErr_Format(PyExc_ValueError, - "'id' must be a non-negative int, got %R", orig); - } - return -1; - } - if (id < 0) { - PyErr_Format(PyExc_ValueError, - "'id' must be a non-negative int, got %R", orig); - return -1; - } - return id; -} - typedef struct interpid { PyObject_HEAD int64_t id; @@ -85,8 +53,31 @@ interpid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) id = ((interpid *)idobj)->id; } else { - id = _Py_CoerceID(idobj); + PyObject *pyid; + if (PyIndex_Check(idobj)) { + pyid = idobj; + Py_INCREF(pyid); + } + else if (PyUnicode_Check(idobj)) { + pyid = PyNumber_Long(idobj); + if (pyid == NULL) { + return NULL; + } + } + else { + PyErr_Format(PyExc_TypeError, + "interpreter ID must be an int, got %.100s", + idobj->ob_type->tp_name); + return NULL; + } + id = PyLong_AsLongLong(pyid); + Py_DECREF(pyid); + if (id == -1 && PyErr_Occurred()) { + return NULL; + } if (id < 0) { + PyErr_Format(PyExc_ValueError, + "interpreter ID must be a non-negative int, got %R", idobj); return NULL; } } @@ -202,23 +193,26 @@ interpid_richcompare(PyObject *self, PyObject *other, int op) interpid *otherid = (interpid *)other; equal = (id->id == otherid->id); } - else { - other = PyNumber_Long(other); - if (other == NULL) { - PyErr_Clear(); - Py_RETURN_NOTIMPLEMENTED; - } - int64_t otherid = PyLong_AsLongLong(other); - Py_DECREF(other); - if (otherid == -1 && PyErr_Occurred() != NULL) { + else if (PyLong_CheckExact(other)) { + /* Fast path */ + int overflow; + long long otherid = PyLong_AsLongLongAndOverflow(other, &overflow); + if (otherid == -1 && PyErr_Occurred()) { return NULL; } - if (otherid < 0) { - equal = 0; - } - else { - equal = (id->id == otherid); + equal = !overflow && (otherid >= 0) && (id->id == otherid); + } + else if (PyNumber_Check(other)) { + PyObject *pyid = PyLong_FromLongLong(id->id); + if (pyid == NULL) { + return NULL; } + PyObject *res = PyObject_RichCompare(pyid, other, op); + Py_DECREF(pyid); + return res; + } + else { + Py_RETURN_NOTIMPLEMENTED; } if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) { @@ -250,8 +244,7 @@ PyTypeObject _PyInterpreterID_Type = { 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | - Py_TPFLAGS_LONG_SUBCLASS, /* tp_flags */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ interpid_doc, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ @@ -262,7 +255,7 @@ PyTypeObject _PyInterpreterID_Type = { 0, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ - &PyLong_Type, /* tp_base */ + 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ @@ -297,12 +290,17 @@ _PyInterpreterID_LookUp(PyObject *requested_id) if (PyObject_TypeCheck(requested_id, &_PyInterpreterID_Type)) { id = ((interpid *)requested_id)->id; } - else { + else if (PyIndex_Check(requested_id)) { id = PyLong_AsLongLong(requested_id); if (id == -1 && PyErr_Occurred() != NULL) { return NULL; } assert(id <= INT64_MAX); } + else { + PyErr_Format(PyExc_TypeError, "interpreter ID must be an int, got %.100s", + requested_id->ob_type->tp_name); + return NULL; + } return _PyInterpreterState_LookUpID(id); }