gh-76785: Module-level Fixes for test.support.interpreters (gh-110236)

* add RecvChannel.close() and SendChannel.close()
* make RecvChannel and SendChannel shareable
* expose ChannelEmptyError and ChannelNotEmptyError
This commit is contained in:
Eric Snow 2023-10-02 14:47:41 -06:00 committed by GitHub
parent 014aacda62
commit a8f5dab58d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 206 additions and 25 deletions

View File

@ -7,7 +7,8 @@ import _xxinterpchannels as _channels
# aliases:
from _xxsubinterpreters import is_shareable
from _xxinterpchannels import (
ChannelError, ChannelNotFoundError, ChannelEmptyError,
ChannelError, ChannelNotFoundError, ChannelClosedError,
ChannelEmptyError, ChannelNotEmptyError,
)
@ -117,10 +118,16 @@ def list_all_channels():
class _ChannelEnd:
"""The base class for RecvChannel and SendChannel."""
def __init__(self, id):
if not isinstance(id, (int, _channels.ChannelID)):
raise TypeError(f'id must be an int, got {id!r}')
self._id = id
_end = None
def __init__(self, cid):
if self._end == 'send':
cid = _channels._channel_id(cid, send=True, force=True)
elif self._end == 'recv':
cid = _channels._channel_id(cid, recv=True, force=True)
else:
raise NotImplementedError(self._end)
self._id = cid
def __repr__(self):
return f'{type(self).__name__}(id={int(self._id)})'
@ -147,6 +154,8 @@ _NOT_SET = object()
class RecvChannel(_ChannelEnd):
"""The receiving end of a cross-interpreter channel."""
_end = 'recv'
def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
"""Return the next object from the channel.
@ -171,10 +180,15 @@ class RecvChannel(_ChannelEnd):
else:
return _channels.recv(self._id, default)
def close(self):
_channels.close(self._id, recv=True)
class SendChannel(_ChannelEnd):
"""The sending end of a cross-interpreter channel."""
_end = 'send'
def send(self, obj):
"""Send the object (i.e. its data) to the channel's receiving end.
@ -196,3 +210,9 @@ class SendChannel(_ChannelEnd):
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj)
def close(self):
_channels.close(self._id, send=True)
_channels._register_end_types(SendChannel, RecvChannel)

View File

@ -822,6 +822,22 @@ class TestChannels(TestBase):
after = set(interpreters.list_all_channels())
self.assertEqual(after, created)
def test_shareable(self):
rch, sch = interpreters.create_channel()
self.assertTrue(
interpreters.is_shareable(rch))
self.assertTrue(
interpreters.is_shareable(sch))
sch.send_nowait(rch)
sch.send_nowait(sch)
rch2 = rch.recv()
sch2 = rch.recv()
self.assertEqual(rch2, rch)
self.assertEqual(sch2, sch)
class TestRecvChannelAttrs(TestBase):

View File

@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
/* module state *************************************************************/
typedef struct {
PyTypeObject *send_channel_type;
PyTypeObject *recv_channel_type;
/* heap types */
PyTypeObject *ChannelIDType;
@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
return state;
}
static module_state *
_get_current_module_state(void)
{
PyObject *mod = _get_current_module();
if (mod == NULL) {
// XXX import it?
PyErr_SetString(PyExc_RuntimeError,
MODULE_NAME " module not imported yet");
return NULL;
}
module_state *state = get_module_state(mod);
Py_DECREF(mod);
return state;
}
static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);
/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
@ -1529,17 +1550,20 @@ typedef struct channelid {
struct channel_id_converter_data {
PyObject *module;
int64_t cid;
int end;
};
static int
channel_id_converter(PyObject *arg, void *ptr)
{
int64_t cid;
int end = 0;
struct channel_id_converter_data *data = ptr;
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
cid = ((channelid *)arg)->id;
end = ((channelid *)arg)->end;
}
else if (PyIndex_Check(arg)) {
cid = PyLong_AsLongLong(arg);
@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
return 0;
}
data->cid = cid;
data->end = end;
return 1;
}
@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
{
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
int64_t cid;
int end;
struct channel_id_converter_data cid_data = {
.module = mod,
};
@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
return NULL;
}
cid = cid_data.cid;
end = cid_data.end;
// Handle "send" and "recv".
if (send == 0 && recv == 0) {
@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
"'send' and 'recv' cannot both be False");
return NULL;
}
int end = 0;
if (send == 1) {
else if (send == 1) {
if (recv == 0 || recv == -1) {
end = CHANNEL_SEND;
}
else {
assert(recv == 1);
end = 0;
}
}
else if (recv == 1) {
assert(send == 0 || send == -1);
end = CHANNEL_RECV;
}
@ -1773,21 +1803,12 @@ done:
return res;
}
static PyTypeObject * _get_current_channel_end_type(int end);
static PyObject *
_channel_from_cid(PyObject *cid, int end)
{
PyObject *highlevel = PyImport_ImportModule("interpreters");
if (highlevel == NULL) {
PyErr_Clear();
highlevel = PyImport_ImportModule("test.support.interpreters");
if (highlevel == NULL) {
return NULL;
}
}
const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
"SendChannel";
PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
Py_DECREF(highlevel);
PyObject *cls = (PyObject *)_get_current_channel_end_type(end);
if (cls == NULL) {
return NULL;
}
@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
};
/* SendChannel and RecvChannel classes */
// XXX Use a new __xid__ protocol instead?
static PyTypeObject *
_get_current_channel_end_type(int end)
{
module_state *state = _get_current_module_state();
if (state == NULL) {
return NULL;
}
PyTypeObject *cls;
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
}
else {
assert(end == CHANNEL_RECV);
cls = state->recv_channel_type;
}
if (cls == NULL) {
PyObject *highlevel = PyImport_ImportModule("interpreters");
if (highlevel == NULL) {
PyErr_Clear();
highlevel = PyImport_ImportModule("test.support.interpreters");
if (highlevel == NULL) {
return NULL;
}
}
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
}
else {
cls = state->recv_channel_type;
}
assert(cls != NULL);
}
return cls;
}
static PyObject *
_channel_end_from_xid(_PyCrossInterpreterData *data)
{
channelid *cid = (channelid *)_channelid_from_xid(data);
if (cid == NULL) {
return NULL;
}
PyTypeObject *cls = _get_current_channel_end_type(cid->end);
if (cls == NULL) {
return NULL;
}
PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
Py_DECREF(cid);
return obj;
}
static int
_channel_end_shared(PyThreadState *tstate, PyObject *obj,
_PyCrossInterpreterData *data)
{
PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
if (cidobj == NULL) {
return -1;
}
if (_channelid_shared(tstate, cidobj, data) < 0) {
return -1;
}
data->new_object = _channel_end_from_xid;
return 0;
}
static int
set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
{
module_state *state = get_module_state(mod);
if (state == NULL) {
return -1;
}
if (state->send_channel_type != NULL
|| state->recv_channel_type != NULL)
{
PyErr_SetString(PyExc_TypeError, "already registered");
return -1;
}
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);
if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
return -1;
}
if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
return -1;
}
return 0;
}
/* module level code ********************************************************/
/* globals is the process-global state for the module. It holds all
@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
PyTypeObject *cls = state->ChannelIDType;
PyObject *mod = get_module_from_owned_type(cls);
if (mod == NULL) {
assert(get_module_from_owned_type(cls) == self);
return _channelid_new(self, cls, args, kwds);
}
static PyObject *
channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"send", "recv", NULL};
PyObject *send;
PyObject *recv;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"OO:_register_end_types", kwlist,
&send, &recv)) {
return NULL;
}
PyObject *cid = _channelid_new(mod, cls, args, kwds);
Py_DECREF(mod);
return cid;
if (!PyType_Check(send)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
return NULL;
}
if (!PyType_Check(recv)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
return NULL;
}
PyTypeObject *cls_send = (PyTypeObject *)send;
PyTypeObject *cls_recv = (PyTypeObject *)recv;
if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
static PyMethodDef module_functions[] = {
@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
METH_VARARGS | METH_KEYWORDS, channel_release_doc},
{"_channel_id", _PyCFunction_CAST(channel__channel_id),
METH_VARARGS | METH_KEYWORDS, NULL},
{"_register_end_types", _PyCFunction_CAST(channel__register_end_types),
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL} /* sentinel */
};