diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index f5474f4101d..207b5db5d8f 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -514,7 +514,7 @@ class IsRunningTests(TestBase): interpreters.is_running(1_000_000) def test_bad_id(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): interpreters.is_running(-1) @@ -530,18 +530,15 @@ class InterpreterIDTests(TestBase): def __index__(self): return 10 - for id in ('10', '1_0', Int()): - with self.subTest(id=id): - id = interpreters.InterpreterID(id, force=True) - self.assertEqual(int(id), 10) + id = interpreters.InterpreterID(Int(), force=True) + self.assertEqual(int(id), 10) def test_bad_id(self): self.assertRaises(TypeError, interpreters.InterpreterID, object()) self.assertRaises(TypeError, interpreters.InterpreterID, 10.0) + self.assertRaises(TypeError, interpreters.InterpreterID, '10') self.assertRaises(TypeError, interpreters.InterpreterID, b'10') self.assertRaises(ValueError, interpreters.InterpreterID, -1) - self.assertRaises(ValueError, interpreters.InterpreterID, '-1') - self.assertRaises(ValueError, interpreters.InterpreterID, 'spam') self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64) def test_does_not_exist(self): @@ -720,7 +717,7 @@ class DestroyTests(TestBase): interpreters.destroy(1_000_000) def test_bad_id(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): interpreters.destroy(-1) def test_from_current(self): @@ -863,7 +860,7 @@ class RunStringTests(TestBase): interpreters.run_string(id, 'print("spam")') def test_error_id(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): interpreters.run_string(-1, 'print("spam")') def test_bad_id(self): diff --git a/Objects/interpreteridobject.c b/Objects/interpreteridobject.c index 3edbb85e6ac..94f5dd709bb 100644 --- a/Objects/interpreteridobject.c +++ b/Objects/interpreteridobject.c @@ -35,53 +35,46 @@ newinterpid(PyTypeObject *cls, int64_t id, int force) return self; } +static int +interp_id_converter(PyObject *arg, void *ptr) +{ + int64_t id; + if (PyObject_TypeCheck(arg, &_PyInterpreterID_Type)) { + id = ((interpid *)arg)->id; + } + else if (PyIndex_Check(arg)) { + id = PyLong_AsLongLong(arg); + if (id == -1 && PyErr_Occurred()) { + return 0; + } + if (id < 0) { + PyErr_Format(PyExc_ValueError, + "interpreter ID must be a non-negative int, got %R", arg); + return 0; + } + } + else { + PyErr_Format(PyExc_TypeError, + "interpreter ID must be an int, got %.100s", + arg->ob_type->tp_name); + return 0; + } + *(int64_t *)ptr = id; + return 1; +} + static PyObject * interpid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"id", "force", NULL}; - PyObject *idobj; + int64_t id; int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$p:InterpreterID.__init__", kwlist, - &idobj, &force)) { + "O&|$p:InterpreterID.__init__", kwlist, + interp_id_converter, &id, &force)) { return NULL; } - // Coerce and check the ID. - int64_t id; - if (PyObject_TypeCheck(idobj, &_PyInterpreterID_Type)) { - id = ((interpid *)idobj)->id; - } - else { - PyObject *pyid; - if (PyIndex_Check(idobj)) { - pyid = idobj; - Py_INCREF(pyid); - } - else if (PyUnicode_Check(idobj)) { - pyid = PyNumber_Long(idobj); - if (pyid == NULL) { - return NULL; - } - } - else { - PyErr_Format(PyExc_TypeError, - "interpreter ID must be an int, got %.100s", - idobj->ob_type->tp_name); - return NULL; - } - id = PyLong_AsLongLong(pyid); - Py_DECREF(pyid); - if (id == -1 && PyErr_Occurred()) { - return NULL; - } - if (id < 0) { - PyErr_Format(PyExc_ValueError, - "interpreter ID must be a non-negative int, got %R", idobj); - return NULL; - } - } - return (PyObject *)newinterpid(cls, id, force); } @@ -287,19 +280,7 @@ PyInterpreterState * _PyInterpreterID_LookUp(PyObject *requested_id) { int64_t id; - if (PyObject_TypeCheck(requested_id, &_PyInterpreterID_Type)) { - id = ((interpid *)requested_id)->id; - } - else if (PyIndex_Check(requested_id)) { - id = PyLong_AsLongLong(requested_id); - if (id == -1 && PyErr_Occurred() != NULL) { - return NULL; - } - assert(id <= INT64_MAX); - } - else { - PyErr_Format(PyExc_TypeError, "interpreter ID must be an int, got %.100s", - requested_id->ob_type->tp_name); + if (!interp_id_converter(requested_id, &id)) { return NULL; } return _PyInterpreterState_LookUpID(id);