mirror of https://github.com/python/cpython
gh-76785: Module-level Fixes for test.support.interpreters (gh-110236)
* add RecvChannel.close() and SendChannel.close() * make RecvChannel and SendChannel shareable * expose ChannelEmptyError and ChannelNotEmptyError
This commit is contained in:
parent
014aacda62
commit
a8f5dab58d
|
@ -7,7 +7,8 @@ import _xxinterpchannels as _channels
|
|||
# aliases:
|
||||
from _xxsubinterpreters import is_shareable
|
||||
from _xxinterpchannels import (
|
||||
ChannelError, ChannelNotFoundError, ChannelEmptyError,
|
||||
ChannelError, ChannelNotFoundError, ChannelClosedError,
|
||||
ChannelEmptyError, ChannelNotEmptyError,
|
||||
)
|
||||
|
||||
|
||||
|
@ -117,10 +118,16 @@ def list_all_channels():
|
|||
class _ChannelEnd:
|
||||
"""The base class for RecvChannel and SendChannel."""
|
||||
|
||||
def __init__(self, id):
|
||||
if not isinstance(id, (int, _channels.ChannelID)):
|
||||
raise TypeError(f'id must be an int, got {id!r}')
|
||||
self._id = id
|
||||
_end = None
|
||||
|
||||
def __init__(self, cid):
|
||||
if self._end == 'send':
|
||||
cid = _channels._channel_id(cid, send=True, force=True)
|
||||
elif self._end == 'recv':
|
||||
cid = _channels._channel_id(cid, recv=True, force=True)
|
||||
else:
|
||||
raise NotImplementedError(self._end)
|
||||
self._id = cid
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}(id={int(self._id)})'
|
||||
|
@ -147,6 +154,8 @@ _NOT_SET = object()
|
|||
class RecvChannel(_ChannelEnd):
|
||||
"""The receiving end of a cross-interpreter channel."""
|
||||
|
||||
_end = 'recv'
|
||||
|
||||
def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
|
||||
"""Return the next object from the channel.
|
||||
|
||||
|
@ -171,10 +180,15 @@ class RecvChannel(_ChannelEnd):
|
|||
else:
|
||||
return _channels.recv(self._id, default)
|
||||
|
||||
def close(self):
|
||||
_channels.close(self._id, recv=True)
|
||||
|
||||
|
||||
class SendChannel(_ChannelEnd):
|
||||
"""The sending end of a cross-interpreter channel."""
|
||||
|
||||
_end = 'send'
|
||||
|
||||
def send(self, obj):
|
||||
"""Send the object (i.e. its data) to the channel's receiving end.
|
||||
|
||||
|
@ -196,3 +210,9 @@ class SendChannel(_ChannelEnd):
|
|||
# None. This should be fixed when channel_send_wait() is added.
|
||||
# See bpo-32604 and gh-19829.
|
||||
return _channels.send(self._id, obj)
|
||||
|
||||
def close(self):
|
||||
_channels.close(self._id, send=True)
|
||||
|
||||
|
||||
_channels._register_end_types(SendChannel, RecvChannel)
|
||||
|
|
|
@ -822,6 +822,22 @@ class TestChannels(TestBase):
|
|||
after = set(interpreters.list_all_channels())
|
||||
self.assertEqual(after, created)
|
||||
|
||||
def test_shareable(self):
|
||||
rch, sch = interpreters.create_channel()
|
||||
|
||||
self.assertTrue(
|
||||
interpreters.is_shareable(rch))
|
||||
self.assertTrue(
|
||||
interpreters.is_shareable(sch))
|
||||
|
||||
sch.send_nowait(rch)
|
||||
sch.send_nowait(sch)
|
||||
rch2 = rch.recv()
|
||||
sch2 = rch.recv()
|
||||
|
||||
self.assertEqual(rch2, rch)
|
||||
self.assertEqual(sch2, sch)
|
||||
|
||||
|
||||
class TestRecvChannelAttrs(TestBase):
|
||||
|
||||
|
|
|
@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
|
|||
/* module state *************************************************************/
|
||||
|
||||
typedef struct {
|
||||
PyTypeObject *send_channel_type;
|
||||
PyTypeObject *recv_channel_type;
|
||||
|
||||
/* heap types */
|
||||
PyTypeObject *ChannelIDType;
|
||||
|
||||
|
@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
|
|||
return state;
|
||||
}
|
||||
|
||||
static module_state *
|
||||
_get_current_module_state(void)
|
||||
{
|
||||
PyObject *mod = _get_current_module();
|
||||
if (mod == NULL) {
|
||||
// XXX import it?
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
MODULE_NAME " module not imported yet");
|
||||
return NULL;
|
||||
}
|
||||
module_state *state = get_module_state(mod);
|
||||
Py_DECREF(mod);
|
||||
return state;
|
||||
}
|
||||
|
||||
static int
|
||||
traverse_module_state(module_state *state, visitproc visit, void *arg)
|
||||
{
|
||||
|
@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
|
|||
static int
|
||||
clear_module_state(module_state *state)
|
||||
{
|
||||
Py_CLEAR(state->send_channel_type);
|
||||
Py_CLEAR(state->recv_channel_type);
|
||||
|
||||
/* heap types */
|
||||
if (state->ChannelIDType != NULL) {
|
||||
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
|
||||
|
@ -1529,17 +1550,20 @@ typedef struct channelid {
|
|||
struct channel_id_converter_data {
|
||||
PyObject *module;
|
||||
int64_t cid;
|
||||
int end;
|
||||
};
|
||||
|
||||
static int
|
||||
channel_id_converter(PyObject *arg, void *ptr)
|
||||
{
|
||||
int64_t cid;
|
||||
int end = 0;
|
||||
struct channel_id_converter_data *data = ptr;
|
||||
module_state *state = get_module_state(data->module);
|
||||
assert(state != NULL);
|
||||
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
|
||||
cid = ((channelid *)arg)->id;
|
||||
end = ((channelid *)arg)->end;
|
||||
}
|
||||
else if (PyIndex_Check(arg)) {
|
||||
cid = PyLong_AsLongLong(arg);
|
||||
|
@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
|
|||
return 0;
|
||||
}
|
||||
data->cid = cid;
|
||||
data->end = end;
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
|
|||
{
|
||||
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
|
||||
int64_t cid;
|
||||
int end;
|
||||
struct channel_id_converter_data cid_data = {
|
||||
.module = mod,
|
||||
};
|
||||
|
@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
|
|||
return NULL;
|
||||
}
|
||||
cid = cid_data.cid;
|
||||
end = cid_data.end;
|
||||
|
||||
// Handle "send" and "recv".
|
||||
if (send == 0 && recv == 0) {
|
||||
|
@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
|
|||
"'send' and 'recv' cannot both be False");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int end = 0;
|
||||
if (send == 1) {
|
||||
else if (send == 1) {
|
||||
if (recv == 0 || recv == -1) {
|
||||
end = CHANNEL_SEND;
|
||||
}
|
||||
else {
|
||||
assert(recv == 1);
|
||||
end = 0;
|
||||
}
|
||||
}
|
||||
else if (recv == 1) {
|
||||
assert(send == 0 || send == -1);
|
||||
end = CHANNEL_RECV;
|
||||
}
|
||||
|
||||
|
@ -1773,21 +1803,12 @@ done:
|
|||
return res;
|
||||
}
|
||||
|
||||
static PyTypeObject * _get_current_channel_end_type(int end);
|
||||
|
||||
static PyObject *
|
||||
_channel_from_cid(PyObject *cid, int end)
|
||||
{
|
||||
PyObject *highlevel = PyImport_ImportModule("interpreters");
|
||||
if (highlevel == NULL) {
|
||||
PyErr_Clear();
|
||||
highlevel = PyImport_ImportModule("test.support.interpreters");
|
||||
if (highlevel == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
|
||||
"SendChannel";
|
||||
PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
|
||||
Py_DECREF(highlevel);
|
||||
PyObject *cls = (PyObject *)_get_current_channel_end_type(end);
|
||||
if (cls == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
|
|||
};
|
||||
|
||||
|
||||
/* SendChannel and RecvChannel classes */
|
||||
|
||||
// XXX Use a new __xid__ protocol instead?
|
||||
|
||||
static PyTypeObject *
|
||||
_get_current_channel_end_type(int end)
|
||||
{
|
||||
module_state *state = _get_current_module_state();
|
||||
if (state == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyTypeObject *cls;
|
||||
if (end == CHANNEL_SEND) {
|
||||
cls = state->send_channel_type;
|
||||
}
|
||||
else {
|
||||
assert(end == CHANNEL_RECV);
|
||||
cls = state->recv_channel_type;
|
||||
}
|
||||
if (cls == NULL) {
|
||||
PyObject *highlevel = PyImport_ImportModule("interpreters");
|
||||
if (highlevel == NULL) {
|
||||
PyErr_Clear();
|
||||
highlevel = PyImport_ImportModule("test.support.interpreters");
|
||||
if (highlevel == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (end == CHANNEL_SEND) {
|
||||
cls = state->send_channel_type;
|
||||
}
|
||||
else {
|
||||
cls = state->recv_channel_type;
|
||||
}
|
||||
assert(cls != NULL);
|
||||
}
|
||||
return cls;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
_channel_end_from_xid(_PyCrossInterpreterData *data)
|
||||
{
|
||||
channelid *cid = (channelid *)_channelid_from_xid(data);
|
||||
if (cid == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyTypeObject *cls = _get_current_channel_end_type(cid->end);
|
||||
if (cls == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
|
||||
Py_DECREF(cid);
|
||||
return obj;
|
||||
}
|
||||
|
||||
static int
|
||||
_channel_end_shared(PyThreadState *tstate, PyObject *obj,
|
||||
_PyCrossInterpreterData *data)
|
||||
{
|
||||
PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
|
||||
if (cidobj == NULL) {
|
||||
return -1;
|
||||
}
|
||||
if (_channelid_shared(tstate, cidobj, data) < 0) {
|
||||
return -1;
|
||||
}
|
||||
data->new_object = _channel_end_from_xid;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int
|
||||
set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
|
||||
{
|
||||
module_state *state = get_module_state(mod);
|
||||
if (state == NULL) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (state->send_channel_type != NULL
|
||||
|| state->recv_channel_type != NULL)
|
||||
{
|
||||
PyErr_SetString(PyExc_TypeError, "already registered");
|
||||
return -1;
|
||||
}
|
||||
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
|
||||
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);
|
||||
|
||||
if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
|
||||
return -1;
|
||||
}
|
||||
if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* module level code ********************************************************/
|
||||
|
||||
/* globals is the process-global state for the module. It holds all
|
||||
|
@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
|
|||
return NULL;
|
||||
}
|
||||
PyTypeObject *cls = state->ChannelIDType;
|
||||
PyObject *mod = get_module_from_owned_type(cls);
|
||||
if (mod == NULL) {
|
||||
assert(get_module_from_owned_type(cls) == self);
|
||||
|
||||
return _channelid_new(self, cls, args, kwds);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
static char *kwlist[] = {"send", "recv", NULL};
|
||||
PyObject *send;
|
||||
PyObject *recv;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds,
|
||||
"OO:_register_end_types", kwlist,
|
||||
&send, &recv)) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *cid = _channelid_new(mod, cls, args, kwds);
|
||||
Py_DECREF(mod);
|
||||
return cid;
|
||||
if (!PyType_Check(send)) {
|
||||
PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
|
||||
return NULL;
|
||||
}
|
||||
if (!PyType_Check(recv)) {
|
||||
PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
|
||||
return NULL;
|
||||
}
|
||||
PyTypeObject *cls_send = (PyTypeObject *)send;
|
||||
PyTypeObject *cls_recv = (PyTypeObject *)recv;
|
||||
|
||||
if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef module_functions[] = {
|
||||
|
@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
|
|||
METH_VARARGS | METH_KEYWORDS, channel_release_doc},
|
||||
{"_channel_id", _PyCFunction_CAST(channel__channel_id),
|
||||
METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"_register_end_types", _PyCFunction_CAST(channel__register_end_types),
|
||||
METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
|
||||
{NULL, NULL} /* sentinel */
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue