bpo-32604: Implement force-closing channels. (gh-6937)

This will make it easier to clean up channels (e.g. when used in tests).
This commit is contained in:
Eric Snow 2018-05-17 10:27:09 -04:00 committed by GitHub
parent 74fc9c0c09
commit 3ab0136ac5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 232 additions and 25 deletions

View File

@ -1379,12 +1379,104 @@ class ChannelTests(TestBase):
with self.assertRaises(interpreters.ChannelClosedError): with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_close(cid) interpreters.channel_close(cid)
def test_close_with_unused_items(self): def test_close_empty(self):
tests = [
(False, False),
(True, False),
(False, True),
(True, True),
]
for send, recv in tests:
with self.subTest((send, recv)):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_recv(cid)
interpreters.channel_close(cid, send=send, recv=recv)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
def test_close_defaults_with_unused_items(self):
cid = interpreters.channel_create() cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham') interpreters.channel_send(cid, b'ham')
interpreters.channel_close(cid)
with self.assertRaises(interpreters.ChannelNotEmptyError):
interpreters.channel_close(cid)
interpreters.channel_recv(cid)
interpreters.channel_send(cid, b'eggs')
def test_close_recv_with_unused_items_unforced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
with self.assertRaises(interpreters.ChannelNotEmptyError):
interpreters.channel_close(cid, recv=True)
interpreters.channel_recv(cid)
interpreters.channel_send(cid, b'eggs')
interpreters.channel_recv(cid)
interpreters.channel_recv(cid)
interpreters.channel_close(cid, recv=True)
def test_close_send_with_unused_items_unforced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
interpreters.channel_close(cid, send=True)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
interpreters.channel_recv(cid)
interpreters.channel_recv(cid)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
def test_close_both_with_unused_items_unforced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
with self.assertRaises(interpreters.ChannelNotEmptyError):
interpreters.channel_close(cid, recv=True, send=True)
interpreters.channel_recv(cid)
interpreters.channel_send(cid, b'eggs')
interpreters.channel_recv(cid)
interpreters.channel_recv(cid)
interpreters.channel_close(cid, recv=True)
def test_close_recv_with_unused_items_forced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
interpreters.channel_close(cid, recv=True, force=True)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
def test_close_send_with_unused_items_forced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
interpreters.channel_close(cid, send=True, force=True)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
def test_close_both_with_unused_items_forced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
interpreters.channel_close(cid, send=True, recv=True, force=True)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
with self.assertRaises(interpreters.ChannelClosedError): with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid) interpreters.channel_recv(cid)
@ -1403,7 +1495,7 @@ class ChannelTests(TestBase):
interp = interpreters.create() interp = interpreters.create()
interpreters.run_string(interp, dedent(f""" interpreters.run_string(interp, dedent(f"""
import _xxsubinterpreters as _interpreters import _xxsubinterpreters as _interpreters
_interpreters.channel_close({cid}) _interpreters.channel_close({cid}, force=True)
""")) """))
with self.assertRaises(interpreters.ChannelClosedError): with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid) interpreters.channel_recv(cid)
@ -1416,7 +1508,7 @@ class ChannelTests(TestBase):
interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'spam')
interpreters.channel_recv(cid) interpreters.channel_recv(cid)
interpreters.channel_close(cid) interpreters.channel_close(cid, force=True)
with self.assertRaises(interpreters.ChannelClosedError): with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs') interpreters.channel_send(cid, b'eggs')

View File

@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
/* channel-specific code ****************************************************/ /* channel-specific code ****************************************************/
#define CHANNEL_SEND 1
#define CHANNEL_BOTH 0
#define CHANNEL_RECV -1
static PyObject *ChannelError; static PyObject *ChannelError;
static PyObject *ChannelNotFoundError; static PyObject *ChannelNotFoundError;
static PyObject *ChannelClosedError; static PyObject *ChannelClosedError;
static PyObject *ChannelEmptyError; static PyObject *ChannelEmptyError;
static PyObject *ChannelNotEmptyError;
static int static int
channel_exceptions_init(PyObject *ns) channel_exceptions_init(PyObject *ns)
@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns)
return -1; return -1;
} }
// An operation tried to close a non-empty channel.
ChannelNotEmptyError = PyErr_NewException(
"_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
if (ChannelNotEmptyError == NULL) {
return -1;
}
if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
return -1;
}
return 0; return 0;
} }
@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
} }
static void static void
_channelends_close_all(_channelends *ends) _channelends_close_all(_channelends *ends, int which, int force)
{ {
// XXX Handle the ends.
// XXX Handle force is True.
// Ensure all the "send"-associated interpreters are closed. // Ensure all the "send"-associated interpreters are closed.
_channelend *end; _channelend *end;
for (end = ends->send; end != NULL; end = end->next) { for (end = ends->send; end != NULL; end = end->next) {
@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends)
/* channels */ /* channels */
struct _channel; struct _channel;
struct _channel_closing;
static void _channel_clear_closing(struct _channel *);
static void _channel_finish_closing(struct _channel *);
typedef struct _channel { typedef struct _channel {
PyThread_type_lock mutex; PyThread_type_lock mutex;
_channelqueue *queue; _channelqueue *queue;
_channelends *ends; _channelends *ends;
int open; int open;
struct _channel_closing *closing;
} _PyChannelState; } _PyChannelState;
static _PyChannelState * static _PyChannelState *
@ -747,12 +769,14 @@ _channel_new(void)
return NULL; return NULL;
} }
chan->open = 1; chan->open = 1;
chan->closing = NULL;
return chan; return chan;
} }
static void static void
_channel_free(_PyChannelState *chan) _channel_free(_PyChannelState *chan)
{ {
_channel_clear_closing(chan);
PyThread_acquire_lock(chan->mutex, WAIT_LOCK); PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
_channelqueue_free(chan->queue); _channelqueue_free(chan->queue);
_channelends_free(chan->ends); _channelends_free(chan->ends);
@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp)
} }
data = _channelqueue_get(chan->queue); data = _channelqueue_get(chan->queue);
if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
chan->open = 0;
}
done: done:
PyThread_release_lock(chan->mutex); PyThread_release_lock(chan->mutex);
if (chan->queue->count == 0) {
_channel_finish_closing(chan);
}
return data; return data;
} }
static int static int
_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
{ {
PyThread_acquire_lock(chan->mutex, WAIT_LOCK); PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
goto done; goto done;
} }
if (_channelends_close_interpreter(chan->ends, interp, which) != 0) { if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
goto done; goto done;
} }
chan->open = _channelends_is_open(chan->ends); chan->open = _channelends_is_open(chan->ends);
@ -830,7 +861,7 @@ done:
} }
static int static int
_channel_close_all(_PyChannelState *chan) _channel_close_all(_PyChannelState *chan, int end, int force)
{ {
int res = -1; int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK); PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan)
goto done; goto done;
} }
if (!force && chan->queue->count > 0) {
PyErr_SetString(ChannelNotEmptyError,
"may not be closed if not empty (try force=True)");
goto done;
}
chan->open = 0; chan->open = 0;
// We *could* also just leave these in place, since we've marked // We *could* also just leave these in place, since we've marked
// the channel as closed already. // the channel as closed already.
_channelends_close_all(chan->ends); _channelends_close_all(chan->ends, end, force);
res = 0; res = 0;
done: done:
@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan)
static void static void
_channelref_free(_channelref *ref) _channelref_free(_channelref *ref)
{ {
if (ref->chan != NULL) {
_channel_clear_closing(ref->chan);
}
//_channelref_clear(ref); //_channelref_clear(ref);
PyMem_Free(ref); PyMem_Free(ref);
} }
@ -1009,8 +1049,12 @@ done:
return cid; return cid;
} }
/* forward */
static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
static int static int
_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan) _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
int end, int force)
{ {
int res = -1; int res = -1;
PyThread_acquire_lock(channels->mutex, WAIT_LOCK); PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
@ -1028,14 +1072,35 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
PyErr_Format(ChannelClosedError, "channel %d closed", cid); PyErr_Format(ChannelClosedError, "channel %d closed", cid);
goto done; goto done;
} }
else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
goto done;
}
else { else {
if (_channel_close_all(ref->chan) != 0) { if (_channel_close_all(ref->chan, end, force) != 0) {
if (end == CHANNEL_SEND &&
PyErr_ExceptionMatches(ChannelNotEmptyError)) {
if (ref->chan->closing != NULL) {
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
goto done;
}
// Mark the channel as closing and return. The channel
// will be cleaned up in _channel_next().
PyErr_Clear();
if (_channel_set_closing(ref, channels->mutex) != 0) {
goto done;
}
if (pchan != NULL) {
*pchan = ref->chan;
}
res = 0;
}
goto done; goto done;
} }
if (pchan != NULL) { if (pchan != NULL) {
*pchan = ref->chan; *pchan = ref->chan;
} }
else { else {
_channel_free(ref->chan); _channel_free(ref->chan);
} }
ref->chan = NULL; ref->chan = NULL;
@ -1161,6 +1226,60 @@ done:
return cids; return cids;
} }
/* support for closing non-empty channels */
struct _channel_closing {
struct _channelref *ref;
};
static int
_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
struct _channel *chan = ref->chan;
if (chan == NULL) {
// already closed
return 0;
}
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
if (chan->closing != NULL) {
PyErr_SetString(ChannelClosedError, "channel closed");
goto done;
}
chan->closing = PyMem_NEW(struct _channel_closing, 1);
if (chan->closing == NULL) {
goto done;
}
chan->closing->ref = ref;
res = 0;
done:
PyThread_release_lock(chan->mutex);
return res;
}
static void
_channel_clear_closing(struct _channel *chan) {
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
if (chan->closing != NULL) {
PyMem_Free(chan->closing);
chan->closing = NULL;
}
PyThread_release_lock(chan->mutex);
}
static void
_channel_finish_closing(struct _channel *chan) {
struct _channel_closing *closing = chan->closing;
if (closing == NULL) {
return;
}
_channelref *ref = closing->ref;
_channel_clear_closing(chan);
// Do the things that would have been done in _channels_close().
ref->chan = NULL;
_channel_free(chan);
};
/* "high"-level channel-related functions */ /* "high"-level channel-related functions */
static int64_t static int64_t
@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
} }
// Past this point we are responsible for releasing the mutex. // Past this point we are responsible for releasing the mutex.
if (chan->closing != NULL) {
PyErr_Format(ChannelClosedError, "channel %d closed", id);
PyThread_release_lock(mutex);
return -1;
}
// Convert the object to cross-interpreter data. // Convert the object to cross-interpreter data.
_PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1); _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
if (data == NULL) { if (data == NULL) {
@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv)
} }
static int static int
_channel_close(_channels *channels, int64_t id) _channel_close(_channels *channels, int64_t id, int end, int force)
{ {
return _channels_close(channels, id, NULL); return _channels_close(channels, id, NULL, end, force);
} }
/* ChannelID class */ /* ChannelID class */
#define CHANNEL_SEND 1
#define CHANNEL_RECV -1
static PyTypeObject ChannelIDtype; static PyTypeObject ChannelIDtype;
typedef struct channelid { typedef struct channelid {
@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds)
if (cid < 0) { if (cid < 0) {
return NULL; return NULL;
} }
if (send == 0 && recv == 0) {
send = 1;
recv = 1;
}
// XXX Handle the ends. if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
// XXX Handle force is True.
if (_channel_close(&_globals.channels, cid) != 0) {
return NULL; return NULL;
} }
Py_RETURN_NONE; Py_RETURN_NONE;