bpo-32604: Implement force-closing channels. (gh-6937)
This will make it easier to clean up channels (e.g. when used in tests).
This commit is contained in:
parent
74fc9c0c09
commit
3ab0136ac5
|
@ -1379,12 +1379,104 @@ class ChannelTests(TestBase):
|
||||||
with self.assertRaises(interpreters.ChannelClosedError):
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
interpreters.channel_close(cid)
|
interpreters.channel_close(cid)
|
||||||
|
|
||||||
def test_close_with_unused_items(self):
|
def test_close_empty(self):
|
||||||
|
tests = [
|
||||||
|
(False, False),
|
||||||
|
(True, False),
|
||||||
|
(False, True),
|
||||||
|
(True, True),
|
||||||
|
]
|
||||||
|
for send, recv in tests:
|
||||||
|
with self.subTest((send, recv)):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_close(cid, send=send, recv=recv)
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
|
||||||
|
def test_close_defaults_with_unused_items(self):
|
||||||
cid = interpreters.channel_create()
|
cid = interpreters.channel_create()
|
||||||
interpreters.channel_send(cid, b'spam')
|
interpreters.channel_send(cid, b'spam')
|
||||||
interpreters.channel_send(cid, b'ham')
|
interpreters.channel_send(cid, b'ham')
|
||||||
interpreters.channel_close(cid)
|
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelNotEmptyError):
|
||||||
|
interpreters.channel_close(cid)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
|
||||||
|
def test_close_recv_with_unused_items_unforced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelNotEmptyError):
|
||||||
|
interpreters.channel_close(cid, recv=True)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_close(cid, recv=True)
|
||||||
|
|
||||||
|
def test_close_send_with_unused_items_unforced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
interpreters.channel_close(cid, send=True)
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
|
||||||
|
def test_close_both_with_unused_items_unforced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelNotEmptyError):
|
||||||
|
interpreters.channel_close(cid, recv=True, send=True)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
interpreters.channel_close(cid, recv=True)
|
||||||
|
|
||||||
|
def test_close_recv_with_unused_items_forced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
interpreters.channel_close(cid, recv=True, force=True)
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
|
||||||
|
def test_close_send_with_unused_items_forced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
interpreters.channel_close(cid, send=True, force=True)
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_recv(cid)
|
||||||
|
|
||||||
|
def test_close_both_with_unused_items_forced(self):
|
||||||
|
cid = interpreters.channel_create()
|
||||||
|
interpreters.channel_send(cid, b'spam')
|
||||||
|
interpreters.channel_send(cid, b'ham')
|
||||||
|
interpreters.channel_close(cid, send=True, recv=True, force=True)
|
||||||
|
|
||||||
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
|
interpreters.channel_send(cid, b'eggs')
|
||||||
with self.assertRaises(interpreters.ChannelClosedError):
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
interpreters.channel_recv(cid)
|
interpreters.channel_recv(cid)
|
||||||
|
|
||||||
|
@ -1403,7 +1495,7 @@ class ChannelTests(TestBase):
|
||||||
interp = interpreters.create()
|
interp = interpreters.create()
|
||||||
interpreters.run_string(interp, dedent(f"""
|
interpreters.run_string(interp, dedent(f"""
|
||||||
import _xxsubinterpreters as _interpreters
|
import _xxsubinterpreters as _interpreters
|
||||||
_interpreters.channel_close({cid})
|
_interpreters.channel_close({cid}, force=True)
|
||||||
"""))
|
"""))
|
||||||
with self.assertRaises(interpreters.ChannelClosedError):
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
interpreters.channel_recv(cid)
|
interpreters.channel_recv(cid)
|
||||||
|
@ -1416,7 +1508,7 @@ class ChannelTests(TestBase):
|
||||||
interpreters.channel_send(cid, b'spam')
|
interpreters.channel_send(cid, b'spam')
|
||||||
interpreters.channel_send(cid, b'spam')
|
interpreters.channel_send(cid, b'spam')
|
||||||
interpreters.channel_recv(cid)
|
interpreters.channel_recv(cid)
|
||||||
interpreters.channel_close(cid)
|
interpreters.channel_close(cid, force=True)
|
||||||
|
|
||||||
with self.assertRaises(interpreters.ChannelClosedError):
|
with self.assertRaises(interpreters.ChannelClosedError):
|
||||||
interpreters.channel_send(cid, b'eggs')
|
interpreters.channel_send(cid, b'eggs')
|
||||||
|
|
|
@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
|
||||||
|
|
||||||
/* channel-specific code ****************************************************/
|
/* channel-specific code ****************************************************/
|
||||||
|
|
||||||
|
#define CHANNEL_SEND 1
|
||||||
|
#define CHANNEL_BOTH 0
|
||||||
|
#define CHANNEL_RECV -1
|
||||||
|
|
||||||
static PyObject *ChannelError;
|
static PyObject *ChannelError;
|
||||||
static PyObject *ChannelNotFoundError;
|
static PyObject *ChannelNotFoundError;
|
||||||
static PyObject *ChannelClosedError;
|
static PyObject *ChannelClosedError;
|
||||||
static PyObject *ChannelEmptyError;
|
static PyObject *ChannelEmptyError;
|
||||||
|
static PyObject *ChannelNotEmptyError;
|
||||||
|
|
||||||
static int
|
static int
|
||||||
channel_exceptions_init(PyObject *ns)
|
channel_exceptions_init(PyObject *ns)
|
||||||
|
@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns)
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An operation tried to close a non-empty channel.
|
||||||
|
ChannelNotEmptyError = PyErr_NewException(
|
||||||
|
"_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
|
||||||
|
if (ChannelNotEmptyError == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
_channelends_close_all(_channelends *ends)
|
_channelends_close_all(_channelends *ends, int which, int force)
|
||||||
{
|
{
|
||||||
|
// XXX Handle the ends.
|
||||||
|
// XXX Handle force is True.
|
||||||
|
|
||||||
// Ensure all the "send"-associated interpreters are closed.
|
// Ensure all the "send"-associated interpreters are closed.
|
||||||
_channelend *end;
|
_channelend *end;
|
||||||
for (end = ends->send; end != NULL; end = end->next) {
|
for (end = ends->send; end != NULL; end = end->next) {
|
||||||
|
@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends)
|
||||||
/* channels */
|
/* channels */
|
||||||
|
|
||||||
struct _channel;
|
struct _channel;
|
||||||
|
struct _channel_closing;
|
||||||
|
static void _channel_clear_closing(struct _channel *);
|
||||||
|
static void _channel_finish_closing(struct _channel *);
|
||||||
|
|
||||||
typedef struct _channel {
|
typedef struct _channel {
|
||||||
PyThread_type_lock mutex;
|
PyThread_type_lock mutex;
|
||||||
_channelqueue *queue;
|
_channelqueue *queue;
|
||||||
_channelends *ends;
|
_channelends *ends;
|
||||||
int open;
|
int open;
|
||||||
|
struct _channel_closing *closing;
|
||||||
} _PyChannelState;
|
} _PyChannelState;
|
||||||
|
|
||||||
static _PyChannelState *
|
static _PyChannelState *
|
||||||
|
@ -747,12 +769,14 @@ _channel_new(void)
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
chan->open = 1;
|
chan->open = 1;
|
||||||
|
chan->closing = NULL;
|
||||||
return chan;
|
return chan;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
_channel_free(_PyChannelState *chan)
|
_channel_free(_PyChannelState *chan)
|
||||||
{
|
{
|
||||||
|
_channel_clear_closing(chan);
|
||||||
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||||
_channelqueue_free(chan->queue);
|
_channelqueue_free(chan->queue);
|
||||||
_channelends_free(chan->ends);
|
_channelends_free(chan->ends);
|
||||||
|
@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp)
|
||||||
}
|
}
|
||||||
|
|
||||||
data = _channelqueue_get(chan->queue);
|
data = _channelqueue_get(chan->queue);
|
||||||
|
if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
|
||||||
|
chan->open = 0;
|
||||||
|
}
|
||||||
|
|
||||||
done:
|
done:
|
||||||
PyThread_release_lock(chan->mutex);
|
PyThread_release_lock(chan->mutex);
|
||||||
|
if (chan->queue->count == 0) {
|
||||||
|
_channel_finish_closing(chan);
|
||||||
|
}
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int
|
static int
|
||||||
_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
|
_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
|
||||||
{
|
{
|
||||||
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||||
|
|
||||||
|
@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_channelends_close_interpreter(chan->ends, interp, which) != 0) {
|
if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
chan->open = _channelends_is_open(chan->ends);
|
chan->open = _channelends_is_open(chan->ends);
|
||||||
|
@ -830,7 +861,7 @@ done:
|
||||||
}
|
}
|
||||||
|
|
||||||
static int
|
static int
|
||||||
_channel_close_all(_PyChannelState *chan)
|
_channel_close_all(_PyChannelState *chan, int end, int force)
|
||||||
{
|
{
|
||||||
int res = -1;
|
int res = -1;
|
||||||
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||||
|
@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan)
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!force && chan->queue->count > 0) {
|
||||||
|
PyErr_SetString(ChannelNotEmptyError,
|
||||||
|
"may not be closed if not empty (try force=True)");
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
|
||||||
chan->open = 0;
|
chan->open = 0;
|
||||||
|
|
||||||
// We *could* also just leave these in place, since we've marked
|
// We *could* also just leave these in place, since we've marked
|
||||||
// the channel as closed already.
|
// the channel as closed already.
|
||||||
_channelends_close_all(chan->ends);
|
_channelends_close_all(chan->ends, end, force);
|
||||||
|
|
||||||
res = 0;
|
res = 0;
|
||||||
done:
|
done:
|
||||||
|
@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan)
|
||||||
static void
|
static void
|
||||||
_channelref_free(_channelref *ref)
|
_channelref_free(_channelref *ref)
|
||||||
{
|
{
|
||||||
|
if (ref->chan != NULL) {
|
||||||
|
_channel_clear_closing(ref->chan);
|
||||||
|
}
|
||||||
//_channelref_clear(ref);
|
//_channelref_clear(ref);
|
||||||
PyMem_Free(ref);
|
PyMem_Free(ref);
|
||||||
}
|
}
|
||||||
|
@ -1009,8 +1049,12 @@ done:
|
||||||
return cid;
|
return cid;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* forward */
|
||||||
|
static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
|
||||||
|
|
||||||
static int
|
static int
|
||||||
_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
|
_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
|
||||||
|
int end, int force)
|
||||||
{
|
{
|
||||||
int res = -1;
|
int res = -1;
|
||||||
PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
|
PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
|
||||||
|
@ -1028,8 +1072,29 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
|
||||||
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
|
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
|
else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
|
||||||
|
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
if (_channel_close_all(ref->chan) != 0) {
|
if (_channel_close_all(ref->chan, end, force) != 0) {
|
||||||
|
if (end == CHANNEL_SEND &&
|
||||||
|
PyErr_ExceptionMatches(ChannelNotEmptyError)) {
|
||||||
|
if (ref->chan->closing != NULL) {
|
||||||
|
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
// Mark the channel as closing and return. The channel
|
||||||
|
// will be cleaned up in _channel_next().
|
||||||
|
PyErr_Clear();
|
||||||
|
if (_channel_set_closing(ref, channels->mutex) != 0) {
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
if (pchan != NULL) {
|
||||||
|
*pchan = ref->chan;
|
||||||
|
}
|
||||||
|
res = 0;
|
||||||
|
}
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
if (pchan != NULL) {
|
if (pchan != NULL) {
|
||||||
|
@ -1161,6 +1226,60 @@ done:
|
||||||
return cids;
|
return cids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* support for closing non-empty channels */
|
||||||
|
|
||||||
|
struct _channel_closing {
|
||||||
|
struct _channelref *ref;
|
||||||
|
};
|
||||||
|
|
||||||
|
static int
|
||||||
|
_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
|
||||||
|
struct _channel *chan = ref->chan;
|
||||||
|
if (chan == NULL) {
|
||||||
|
// already closed
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
int res = -1;
|
||||||
|
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||||
|
if (chan->closing != NULL) {
|
||||||
|
PyErr_SetString(ChannelClosedError, "channel closed");
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
chan->closing = PyMem_NEW(struct _channel_closing, 1);
|
||||||
|
if (chan->closing == NULL) {
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
chan->closing->ref = ref;
|
||||||
|
|
||||||
|
res = 0;
|
||||||
|
done:
|
||||||
|
PyThread_release_lock(chan->mutex);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
_channel_clear_closing(struct _channel *chan) {
|
||||||
|
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||||
|
if (chan->closing != NULL) {
|
||||||
|
PyMem_Free(chan->closing);
|
||||||
|
chan->closing = NULL;
|
||||||
|
}
|
||||||
|
PyThread_release_lock(chan->mutex);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
_channel_finish_closing(struct _channel *chan) {
|
||||||
|
struct _channel_closing *closing = chan->closing;
|
||||||
|
if (closing == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_channelref *ref = closing->ref;
|
||||||
|
_channel_clear_closing(chan);
|
||||||
|
// Do the things that would have been done in _channels_close().
|
||||||
|
ref->chan = NULL;
|
||||||
|
_channel_free(chan);
|
||||||
|
};
|
||||||
|
|
||||||
/* "high"-level channel-related functions */
|
/* "high"-level channel-related functions */
|
||||||
|
|
||||||
static int64_t
|
static int64_t
|
||||||
|
@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
|
||||||
}
|
}
|
||||||
// Past this point we are responsible for releasing the mutex.
|
// Past this point we are responsible for releasing the mutex.
|
||||||
|
|
||||||
|
if (chan->closing != NULL) {
|
||||||
|
PyErr_Format(ChannelClosedError, "channel %d closed", id);
|
||||||
|
PyThread_release_lock(mutex);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
// Convert the object to cross-interpreter data.
|
// Convert the object to cross-interpreter data.
|
||||||
_PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
|
_PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
|
||||||
if (data == NULL) {
|
if (data == NULL) {
|
||||||
|
@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv)
|
||||||
}
|
}
|
||||||
|
|
||||||
static int
|
static int
|
||||||
_channel_close(_channels *channels, int64_t id)
|
_channel_close(_channels *channels, int64_t id, int end, int force)
|
||||||
{
|
{
|
||||||
return _channels_close(channels, id, NULL);
|
return _channels_close(channels, id, NULL, end, force);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ChannelID class */
|
/* ChannelID class */
|
||||||
|
|
||||||
#define CHANNEL_SEND 1
|
|
||||||
#define CHANNEL_RECV -1
|
|
||||||
|
|
||||||
static PyTypeObject ChannelIDtype;
|
static PyTypeObject ChannelIDtype;
|
||||||
|
|
||||||
typedef struct channelid {
|
typedef struct channelid {
|
||||||
|
@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds)
|
||||||
if (cid < 0) {
|
if (cid < 0) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
if (send == 0 && recv == 0) {
|
|
||||||
send = 1;
|
|
||||||
recv = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// XXX Handle the ends.
|
if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
|
||||||
// XXX Handle force is True.
|
|
||||||
|
|
||||||
if (_channel_close(&_globals.channels, cid) != 0) {
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
|
|
Loading…
Reference in New Issue