From c58c63fdf615a1c2bfc995dd0b938d82e32b6cde Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 17 Oct 2023 17:05:49 -0600 Subject: [PATCH] gh-84570: Add Timeouts to SendChannel.send() and RecvChannel.recv() (gh-110567) --- Include/internal/pycore_pythread.h | 6 ++ Lib/test/support/interpreters.py | 20 +++-- Lib/test/test__xxinterpchannels.py | 134 ++++++++++++++++++++++++----- Lib/test/test_interpreters.py | 5 ++ Modules/_queuemodule.c | 2 + Modules/_threadmodule.c | 11 +-- Modules/_xxinterpchannelsmodule.c | 43 +++++---- Python/thread.c | 34 ++++++++ 8 files changed, 205 insertions(+), 50 deletions(-) diff --git a/Include/internal/pycore_pythread.h b/Include/internal/pycore_pythread.h index ffd7398eaee..d31ffc78130 100644 --- a/Include/internal/pycore_pythread.h +++ b/Include/internal/pycore_pythread.h @@ -89,6 +89,12 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock); // unset: -1 seconds, in nanoseconds #define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000)) +// Exported for the _xxinterpchannels module. +PyAPI_FUNC(int) PyThread_ParseTimeoutArg( + PyObject *arg, + int blocking, + PY_TIMEOUT_T *timeout); + /* Helper to acquire an interruptible lock with a timeout. If the lock acquire * is interrupted, signal handlers are run, and if they raise an exception, * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py index 9ba6862a9ee..f8f42c0e024 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters.py @@ -170,15 +170,25 @@ class RecvChannel(_ChannelEnd): _end = 'recv' - def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds + def recv(self, timeout=None, *, + _sentinel=object(), + _delay=10 / 1000, # 10 milliseconds + ): """Return the next object from the channel. This blocks until an object has been sent, if none have been sent already. """ + if timeout is not None: + timeout = int(timeout) + if timeout < 0: + raise ValueError(f'timeout value must be non-negative') + end = time.time() + timeout obj = _channels.recv(self._id, _sentinel) while obj is _sentinel: time.sleep(_delay) + if timeout is not None and time.time() >= end: + raise TimeoutError obj = _channels.recv(self._id, _sentinel) return obj @@ -203,12 +213,12 @@ class SendChannel(_ChannelEnd): _end = 'send' - def send(self, obj): + def send(self, obj, timeout=None): """Send the object (i.e. its data) to the channel's receiving end. This blocks until the object is received. """ - _channels.send(self._id, obj, blocking=True) + _channels.send(self._id, obj, timeout=timeout, blocking=True) def send_nowait(self, obj): """Send the object to the channel's receiving end. @@ -221,12 +231,12 @@ class SendChannel(_ChannelEnd): # See bpo-32604 and gh-19829. return _channels.send(self._id, obj, blocking=False) - def send_buffer(self, obj): + def send_buffer(self, obj, timeout=None): """Send the object's buffer to the channel's receiving end. This blocks until the object is received. """ - _channels.send_buffer(self._id, obj, blocking=True) + _channels.send_buffer(self._id, obj, timeout=timeout, blocking=True) def send_buffer_nowait(self, obj): """Send the object's buffer to the channel's receiving end. diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py index 90a1224498f..1c1ef3fac9d 100644 --- a/Lib/test/test__xxinterpchannels.py +++ b/Lib/test/test__xxinterpchannels.py @@ -864,22 +864,34 @@ class ChannelTests(TestBase): self.assertEqual(received, obj) - def test_send_closed_while_waiting(self): + def test_send_timeout(self): obj = b'spam' - wait = self.build_send_waiter(obj) - cid = channels.create() - def f(): - wait() - channels.close(cid, force=True) - t = threading.Thread(target=f) - t.start() - with self.assertRaises(channels.ChannelClosedError): - channels.send(cid, obj, blocking=True) - t.join() - def test_send_buffer_closed_while_waiting(self): + with self.subTest('non-blocking with timeout'): + cid = channels.create() + with self.assertRaises(ValueError): + channels.send(cid, obj, blocking=False, timeout=0.1) + + with self.subTest('timeout hit'): + cid = channels.create() + with self.assertRaises(TimeoutError): + channels.send(cid, obj, blocking=True, timeout=0.1) + with self.assertRaises(channels.ChannelEmptyError): + received = channels.recv(cid) + print(repr(received)) + + with self.subTest('timeout not hit'): + cid = channels.create() + def f(): + recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send(cid, obj, blocking=True, timeout=10) + t.join() + + def test_send_buffer_timeout(self): try: - self._has_run_once + self._has_run_once_timeout except AttributeError: # At the moment, this test leaks a few references. # It looks like the leak originates with the addition @@ -888,19 +900,95 @@ class ChannelTests(TestBase): # if the refleak isn't fixed yet, so we skip here. raise unittest.SkipTest('temporarily skipped due to refleaks') else: - self._has_run_once = True + self._has_run_once_timeout = True + + obj = bytearray(b'spam') + + with self.subTest('non-blocking with timeout'): + cid = channels.create() + with self.assertRaises(ValueError): + channels.send_buffer(cid, obj, blocking=False, timeout=0.1) + + with self.subTest('timeout hit'): + cid = channels.create() + with self.assertRaises(TimeoutError): + channels.send_buffer(cid, obj, blocking=True, timeout=0.1) + with self.assertRaises(channels.ChannelEmptyError): + received = channels.recv(cid) + print(repr(received)) + + with self.subTest('timeout not hit'): + cid = channels.create() + def f(): + recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send_buffer(cid, obj, blocking=True, timeout=10) + t.join() + + def test_send_closed_while_waiting(self): + obj = b'spam' + wait = self.build_send_waiter(obj) + + with self.subTest('without timeout'): + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send(cid, obj, blocking=True) + t.join() + + with self.subTest('with timeout'): + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send(cid, obj, blocking=True, timeout=30) + t.join() + + def test_send_buffer_closed_while_waiting(self): + try: + self._has_run_once_closed + except AttributeError: + # At the moment, this test leaks a few references. + # It looks like the leak originates with the addition + # of _channels.send_buffer() (gh-110246), whereas the + # tests were added afterward. We want this test even + # if the refleak isn't fixed yet, so we skip here. + raise unittest.SkipTest('temporarily skipped due to refleaks') + else: + self._has_run_once_closed = True obj = bytearray(b'spam') wait = self.build_send_waiter(obj, buffer=True) - cid = channels.create() - def f(): - wait() - channels.close(cid, force=True) - t = threading.Thread(target=f) - t.start() - with self.assertRaises(channels.ChannelClosedError): - channels.send_buffer(cid, obj, blocking=True) - t.join() + + with self.subTest('without timeout'): + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send_buffer(cid, obj, blocking=True) + t.join() + + with self.subTest('with timeout'): + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send_buffer(cid, obj, blocking=True, timeout=30) + t.join() #------------------- # close diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py index 0910b51bfe5..d2d52ec9a78 100644 --- a/Lib/test/test_interpreters.py +++ b/Lib/test/test_interpreters.py @@ -1022,6 +1022,11 @@ class TestSendRecv(TestBase): self.assertEqual(obj2, b'eggs') self.assertNotEqual(id(obj2), int(out)) + def test_recv_timeout(self): + r, _ = interpreters.create_channel() + with self.assertRaises(TimeoutError): + r.recv(timeout=1) + def test_recv_channel_does_not_exist(self): ch = interpreters.RecvChannel(1_000_000) with self.assertRaises(interpreters.ChannelNotFoundError): diff --git a/Modules/_queuemodule.c b/Modules/_queuemodule.c index b4bafb375c9..81a06cdb79a 100644 --- a/Modules/_queuemodule.c +++ b/Modules/_queuemodule.c @@ -214,6 +214,8 @@ _queue_SimpleQueue_get_impl(simplequeueobject *self, PyTypeObject *cls, PY_TIMEOUT_T microseconds; PyThreadState *tstate = PyThreadState_Get(); + // XXX Use PyThread_ParseTimeoutArg(). + if (block == 0) { /* Non-blocking */ microseconds = 0; diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index 7620511dd1d..4d453040503 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -88,14 +88,15 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds, char *kwlist[] = {"blocking", "timeout", NULL}; int blocking = 1; PyObject *timeout_obj = NULL; - const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1); - - *timeout = unset_timeout ; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist, &blocking, &timeout_obj)) return -1; + // XXX Use PyThread_ParseTimeoutArg(). + + const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1); + *timeout = unset_timeout; + if (timeout_obj && _PyTime_FromSecondsObject(timeout, timeout_obj, _PyTime_ROUND_TIMEOUT) < 0) @@ -108,7 +109,7 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds, } if (*timeout < 0 && *timeout != unset_timeout) { PyErr_SetString(PyExc_ValueError, - "timeout value must be positive"); + "timeout value must be a non-negative number"); return -1; } if (!blocking) diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index be53cbfc39b..2e2878d5c20 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -242,9 +242,8 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared, } static int -wait_for_lock(PyThread_type_lock mutex) +wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout) { - PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT; PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout); if (res == PY_LOCK_INTR) { /* KeyboardInterrupt, etc. */ @@ -1883,7 +1882,8 @@ _channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting) } static int -_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj) +_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj, + PY_TIMEOUT_T timeout) { // We use a stack variable here, so we must ensure that &waiting // is not held by any channel item at the point this function exits. @@ -1901,7 +1901,7 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj) } /* Wait until the object is received. */ - if (wait_for_lock(waiting.mutex) < 0) { + if (wait_for_lock(waiting.mutex, timeout) < 0) { assert(PyErr_Occurred()); _waiting_finish_releasing(&waiting); /* The send() call is failing now, so make sure the item @@ -2816,25 +2816,29 @@ receive end."); static PyObject * channel_send(PyObject *self, PyObject *args, PyObject *kwds) { - // XXX Add a timeout arg. - static char *kwlist[] = {"cid", "obj", "blocking", NULL}; - int64_t cid; + static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL}; struct channel_id_converter_data cid_data = { .module = self, }; PyObject *obj; int blocking = 1; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist, + PyObject *timeout_obj = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$pO:channel_send", kwlist, channel_id_converter, &cid_data, &obj, - &blocking)) { + &blocking, &timeout_obj)) { + return NULL; + } + + int64_t cid = cid_data.cid; + PY_TIMEOUT_T timeout; + if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) { return NULL; } - cid = cid_data.cid; /* Queue up the object. */ int err = 0; if (blocking) { - err = _channel_send_wait(&_globals.channels, cid, obj); + err = _channel_send_wait(&_globals.channels, cid, obj, timeout); } else { err = _channel_send(&_globals.channels, cid, obj, NULL); @@ -2855,20 +2859,25 @@ By default this waits for the object to be received."); static PyObject * channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"cid", "obj", "blocking", NULL}; - int64_t cid; + static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL}; struct channel_id_converter_data cid_data = { .module = self, }; PyObject *obj; int blocking = 1; + PyObject *timeout_obj = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O&O|$p:channel_send_buffer", kwlist, + "O&O|$pO:channel_send_buffer", kwlist, channel_id_converter, &cid_data, &obj, - &blocking)) { + &blocking, &timeout_obj)) { + return NULL; + } + + int64_t cid = cid_data.cid; + PY_TIMEOUT_T timeout; + if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) { return NULL; } - cid = cid_data.cid; PyObject *tempobj = PyMemoryView_FromObject(obj); if (tempobj == NULL) { @@ -2878,7 +2887,7 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds) /* Queue up the object. */ int err = 0; if (blocking) { - err = _channel_send_wait(&_globals.channels, cid, tempobj); + err = _channel_send_wait(&_globals.channels, cid, tempobj, timeout); } else { err = _channel_send(&_globals.channels, cid, tempobj, NULL); diff --git a/Python/thread.c b/Python/thread.c index 7185dd43d96..fefae839161 100644 --- a/Python/thread.c +++ b/Python/thread.c @@ -93,6 +93,40 @@ PyThread_set_stacksize(size_t size) } +int +PyThread_ParseTimeoutArg(PyObject *arg, int blocking, PY_TIMEOUT_T *timeout_p) +{ + assert(_PyTime_FromSeconds(-1) == PyThread_UNSET_TIMEOUT); + if (arg == NULL || arg == Py_None) { + *timeout_p = blocking ? PyThread_UNSET_TIMEOUT : 0; + return 0; + } + if (!blocking) { + PyErr_SetString(PyExc_ValueError, + "can't specify a timeout for a non-blocking call"); + return -1; + } + + _PyTime_t timeout; + if (_PyTime_FromSecondsObject(&timeout, arg, _PyTime_ROUND_TIMEOUT) < 0) { + return -1; + } + if (timeout < 0) { + PyErr_SetString(PyExc_ValueError, + "timeout value must be a non-negative number"); + return -1; + } + + if (_PyTime_AsMicroseconds(timeout, + _PyTime_ROUND_TIMEOUT) > PY_TIMEOUT_MAX) { + PyErr_SetString(PyExc_OverflowError, + "timeout value is too large"); + return -1; + } + *timeout_p = timeout; + return 0; +} + PyLockStatus PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock, PY_TIMEOUT_T timeout)