gh-84570: Implement Waiting in SendChannel.send() (gh-110565)

We had been faking it (poorly).

We will add timeouts separately.
This commit is contained in:
Eric Snow 2023-10-10 03:35:14 -06:00 committed by GitHub
parent 46462ff929
commit 757cc35b6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 182 additions and 109 deletions

View File

@ -208,11 +208,7 @@ class SendChannel(_ChannelEnd):
This blocks until the object is received.
"""
_channels.send(self._id, obj)
# XXX We are missing a low-level channel_send_wait().
# See bpo-32604 and gh-19829.
# Until that shows up we fake it:
time.sleep(2)
_channels.send(self._id, obj, blocking=True)
def send_nowait(self, obj):
"""Send the object to the channel's receiving end.
@ -223,14 +219,14 @@ class SendChannel(_ChannelEnd):
# XXX Note that at the moment channel_send() only ever returns
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj)
return _channels.send(self._id, obj, blocking=False)
def send_buffer(self, obj):
"""Send the object's buffer to the channel's receiving end.
This blocks until the object is received.
"""
_channels.send_buffer(self._id, obj)
_channels.send_buffer(self._id, obj, blocking=True)
def send_buffer_nowait(self, obj):
"""Send the object's buffer to the channel's receiving end.
@ -238,7 +234,7 @@ class SendChannel(_ChannelEnd):
If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
return _channels.send_buffer(self._id, obj)
return _channels.send_buffer(self._id, obj, blocking=False)
def close(self):
_channels.close(self._id, send=True)

View File

@ -21,6 +21,13 @@ channels = import_helper.import_module('_xxinterpchannels')
##################################
# helpers
def recv_wait(cid):
while True:
try:
return channels.recv(cid)
except channels.ChannelEmptyError:
time.sleep(0.1)
#@contextmanager
#def run_threaded(id, source, **shared):
# def run():
@ -189,7 +196,7 @@ def run_action(cid, action, end, state, *, hideclosed=True):
def _run_action(cid, action, end, state):
if action == 'use':
if end == 'send':
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
return state.incr()
elif end == 'recv':
if not state.pending:
@ -332,7 +339,7 @@ class ChannelIDTests(TestBase):
chan = channels.create()
obj = channels.create()
channels.send(chan, obj)
channels.send(chan, obj, blocking=False)
got = channels.recv(chan)
self.assertEqual(got, obj)
@ -390,7 +397,7 @@ class ChannelTests(TestBase):
"""Test basic listing channel interpreters."""
interp0 = interpreters.get_main()
cid = channels.create()
channels.send(cid, "send")
channels.send(cid, "send", blocking=False)
# Test for a channel that has one end associated to an interpreter.
send_interps = channels.list_interpreters(cid, send=True)
recv_interps = channels.list_interpreters(cid, send=False)
@ -416,10 +423,10 @@ class ChannelTests(TestBase):
interp3 = interpreters.create()
cid = channels.create()
channels.send(cid, "send")
channels.send(cid, "send", blocking=False)
_run_output(interp1, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid}, "send")
_channels.send({cid}, "send", blocking=False)
"""))
_run_output(interp2, dedent(f"""
import _xxinterpchannels as _channels
@ -439,7 +446,7 @@ class ChannelTests(TestBase):
interp0 = interpreters.get_main()
interp1 = interpreters.create()
cid = channels.create()
channels.send(cid, "send")
channels.send(cid, "send", blocking=False)
_run_output(interp1, dedent(f"""
import _xxinterpchannels as _channels
obj = _channels.recv({cid})
@ -465,12 +472,12 @@ class ChannelTests(TestBase):
interp1 = interpreters.create()
interp2 = interpreters.create()
cid = channels.create()
channels.send(cid, "data")
channels.send(cid, "data", blocking=False)
_run_output(interp1, dedent(f"""
import _xxinterpchannels as _channels
obj = _channels.recv({cid})
"""))
channels.send(cid, "data")
channels.send(cid, "data", blocking=False)
_run_output(interp2, dedent(f"""
import _xxinterpchannels as _channels
obj = _channels.recv({cid})
@ -506,7 +513,7 @@ class ChannelTests(TestBase):
interp1 = interpreters.create()
cid = channels.create()
# Put something in the channel so that it's not empty.
channels.send(cid, "send")
channels.send(cid, "send", blocking=False)
# Check initial state.
send_interps = channels.list_interpreters(cid, send=True)
@ -528,7 +535,7 @@ class ChannelTests(TestBase):
interp1 = interpreters.create()
cid = channels.create()
# Put something in the channel so that it's not empty.
channels.send(cid, "send")
channels.send(cid, "send", blocking=False)
# Check initial state.
send_interps = channels.list_interpreters(cid, send=True)
@ -562,7 +569,7 @@ class ChannelTests(TestBase):
def test_send_recv_main(self):
cid = channels.create()
orig = b'spam'
channels.send(cid, orig)
channels.send(cid, orig, blocking=False)
obj = channels.recv(cid)
self.assertEqual(obj, orig)
@ -574,7 +581,7 @@ class ChannelTests(TestBase):
import _xxinterpchannels as _channels
cid = _channels.create()
orig = b'spam'
_channels.send(cid, orig)
_channels.send(cid, orig, blocking=False)
obj = _channels.recv(cid)
assert obj is not orig
assert obj == orig
@ -585,7 +592,7 @@ class ChannelTests(TestBase):
id1 = interpreters.create()
out = _run_output(id1, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid}, b'spam')
_channels.send({cid}, b'spam', blocking=False)
"""))
obj = channels.recv(cid)
@ -595,19 +602,14 @@ class ChannelTests(TestBase):
cid = channels.create()
def f():
while True:
try:
obj = channels.recv(cid)
break
except channels.ChannelEmptyError:
time.sleep(0.1)
obj = recv_wait(cid)
channels.send(cid, obj)
t = threading.Thread(target=f)
t.start()
channels.send(cid, b'spam')
obj = recv_wait(cid)
t.join()
obj = channels.recv(cid)
self.assertEqual(obj, b'spam')
@ -634,8 +636,8 @@ class ChannelTests(TestBase):
t.start()
channels.send(cid, b'spam')
obj = recv_wait(cid)
t.join()
obj = channels.recv(cid)
self.assertEqual(obj, b'eggs')
@ -656,10 +658,10 @@ class ChannelTests(TestBase):
default = object()
cid = channels.create()
obj1 = channels.recv(cid, default)
channels.send(cid, None)
channels.send(cid, 1)
channels.send(cid, b'spam')
channels.send(cid, b'eggs')
channels.send(cid, None, blocking=False)
channels.send(cid, 1, blocking=False)
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'eggs', blocking=False)
obj2 = channels.recv(cid, default)
obj3 = channels.recv(cid, default)
obj4 = channels.recv(cid)
@ -679,7 +681,7 @@ class ChannelTests(TestBase):
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid1}, b'spam')
_channels.send({cid1}, b'spam', blocking=False)
"""))
interpreters.destroy(interp)
@ -692,9 +694,9 @@ class ChannelTests(TestBase):
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid2}, b'spam')
_channels.send({cid2}, b'spam', blocking=False)
"""))
channels.send(cid2, b'eggs')
channels.send(cid2, b'eggs', blocking=False)
interpreters.destroy(interp)
channels.recv(cid2)
@ -706,7 +708,7 @@ class ChannelTests(TestBase):
def test_send_buffer(self):
buf = bytearray(b'spamspamspam')
cid = channels.create()
channels.send_buffer(cid, buf)
channels.send_buffer(cid, buf, blocking=False)
obj = channels.recv(cid)
self.assertIsNot(obj, buf)
@ -728,7 +730,7 @@ class ChannelTests(TestBase):
]
for obj in objects:
with self.subTest(obj):
channels.send(cid, obj)
channels.send(cid, obj, blocking=False)
got = channels.recv(cid)
self.assertEqual(got, obj)
@ -744,7 +746,7 @@ class ChannelTests(TestBase):
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(cid.end)
_channels.send(cid, b'spam')
_channels.send(cid, b'spam', blocking=False)
"""),
dict(cid=cid.send))
obj = channels.recv(cid)
@ -764,7 +766,7 @@ class ChannelTests(TestBase):
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(chan.id.end)
_channels.send(chan.id, b'spam')
_channels.send(chan.id, b'spam', blocking=False)
"""),
dict(chan=cid.send))
obj = channels.recv(cid)
@ -776,7 +778,7 @@ class ChannelTests(TestBase):
def test_close_single_user(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.close(cid)
@ -791,7 +793,7 @@ class ChannelTests(TestBase):
id2 = interpreters.create()
interpreters.run_string(id1, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid}, b'spam')
_channels.send({cid}, b'spam', blocking=False)
"""))
interpreters.run_string(id2, dedent(f"""
import _xxinterpchannels as _channels
@ -811,7 +813,7 @@ class ChannelTests(TestBase):
def test_close_multiple_times(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.close(cid)
@ -828,7 +830,7 @@ class ChannelTests(TestBase):
for send, recv in tests:
with self.subTest((send, recv)):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.close(cid, send=send, recv=recv)
@ -839,31 +841,31 @@ class ChannelTests(TestBase):
def test_close_defaults_with_unused_items(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
with self.assertRaises(channels.ChannelNotEmptyError):
channels.close(cid)
channels.recv(cid)
channels.send(cid, b'eggs')
channels.send(cid, b'eggs', blocking=False)
def test_close_recv_with_unused_items_unforced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
with self.assertRaises(channels.ChannelNotEmptyError):
channels.close(cid, recv=True)
channels.recv(cid)
channels.send(cid, b'eggs')
channels.send(cid, b'eggs', blocking=False)
channels.recv(cid)
channels.recv(cid)
channels.close(cid, recv=True)
def test_close_send_with_unused_items_unforced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
channels.close(cid, send=True)
with self.assertRaises(channels.ChannelClosedError):
@ -875,21 +877,21 @@ class ChannelTests(TestBase):
def test_close_both_with_unused_items_unforced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
with self.assertRaises(channels.ChannelNotEmptyError):
channels.close(cid, recv=True, send=True)
channels.recv(cid)
channels.send(cid, b'eggs')
channels.send(cid, b'eggs', blocking=False)
channels.recv(cid)
channels.recv(cid)
channels.close(cid, recv=True)
def test_close_recv_with_unused_items_forced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
channels.close(cid, recv=True, force=True)
with self.assertRaises(channels.ChannelClosedError):
@ -899,8 +901,8 @@ class ChannelTests(TestBase):
def test_close_send_with_unused_items_forced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
channels.close(cid, send=True, force=True)
with self.assertRaises(channels.ChannelClosedError):
@ -910,8 +912,8 @@ class ChannelTests(TestBase):
def test_close_both_with_unused_items_forced(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
channels.close(cid, send=True, recv=True, force=True)
with self.assertRaises(channels.ChannelClosedError):
@ -930,7 +932,7 @@ class ChannelTests(TestBase):
def test_close_by_unassociated_interp(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxinterpchannels as _channels
@ -943,9 +945,9 @@ class ChannelTests(TestBase):
def test_close_used_multiple_times_by_single_user(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam')
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.close(cid, force=True)
@ -1017,7 +1019,7 @@ class ChannelReleaseTests(TestBase):
def test_single_user(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.release(cid, send=True, recv=True)
@ -1032,7 +1034,7 @@ class ChannelReleaseTests(TestBase):
id2 = interpreters.create()
interpreters.run_string(id1, dedent(f"""
import _xxinterpchannels as _channels
_channels.send({cid}, b'spam')
_channels.send({cid}, b'spam', blocking=False)
"""))
out = _run_output(id2, dedent(f"""
import _xxinterpchannels as _channels
@ -1048,7 +1050,7 @@ class ChannelReleaseTests(TestBase):
def test_no_kwargs(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.release(cid)
@ -1059,7 +1061,7 @@ class ChannelReleaseTests(TestBase):
def test_multiple_times(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.release(cid, send=True, recv=True)
@ -1068,8 +1070,8 @@ class ChannelReleaseTests(TestBase):
def test_with_unused_items(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'ham')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'ham', blocking=False)
channels.release(cid, send=True, recv=True)
with self.assertRaises(channels.ChannelClosedError):
@ -1086,7 +1088,7 @@ class ChannelReleaseTests(TestBase):
def test_by_unassociated_interp(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxinterpchannels as _channels
@ -1105,7 +1107,7 @@ class ChannelReleaseTests(TestBase):
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxinterpchannels as _channels
obj = _channels.send({cid}, b'spam')
obj = _channels.send({cid}, b'spam', blocking=False)
_channels.release({cid})
"""))
@ -1115,9 +1117,9 @@ class ChannelReleaseTests(TestBase):
def test_partially(self):
# XXX Is partial close too weird/confusing?
cid = channels.create()
channels.send(cid, None)
channels.send(cid, None, blocking=False)
channels.recv(cid)
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.release(cid, send=True)
obj = channels.recv(cid)
@ -1125,9 +1127,9 @@ class ChannelReleaseTests(TestBase):
def test_used_multiple_times_by_single_user(self):
cid = channels.create()
channels.send(cid, b'spam')
channels.send(cid, b'spam')
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'spam', blocking=False)
channels.send(cid, b'spam', blocking=False)
channels.recv(cid)
channels.release(cid, send=True, recv=True)
@ -1212,7 +1214,7 @@ class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
cid = _xxsubchannels.create()
# We purposefully send back an int to avoid tying the
# channel to the other interpreter.
_xxsubchannels.send({ch}, int(cid))
_xxsubchannels.send({ch}, int(cid), blocking=False)
del _xxsubinterpreters
""")
self._cid = channels.recv(ch)
@ -1442,8 +1444,8 @@ class ExhaustiveChannelTests(TestBase):
{repr(fix.state)},
hideclosed={hideclosed},
)
channels.send({_cid}, result.pending.to_bytes(1, 'little'))
channels.send({_cid}, b'X' if result.closed else b'')
channels.send({_cid}, result.pending.to_bytes(1, 'little'), blocking=False)
channels.send({_cid}, b'X' if result.closed else b'', blocking=False)
""")
result = ChannelState(
pending=int.from_bytes(channels.recv(_cid), 'little'),
@ -1490,7 +1492,7 @@ class ExhaustiveChannelTests(TestBase):
""")
run_interp(interp.id, """
with helpers.expect_channel_closed():
channels.send(cid, b'spam')
channels.send(cid, b'spam', blocking=False)
""")
run_interp(interp.id, """
with helpers.expect_channel_closed():

View File

@ -964,8 +964,8 @@ class TestSendRecv(TestBase):
orig = b'spam'
s.send(orig)
t.join()
obj = r.recv()
t.join()
self.assertEqual(obj, orig)
self.assertIsNot(obj, orig)

View File

@ -234,6 +234,17 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
return cls;
}
static void
wait_for_lock(PyThread_type_lock mutex)
{
Py_BEGIN_ALLOW_THREADS
// XXX Handle eintr, etc.
PyThread_acquire_lock(mutex, WAIT_LOCK);
Py_END_ALLOW_THREADS
PyThread_release_lock(mutex);
}
/* Cross-interpreter Buffer Views *******************************************/
@ -567,6 +578,7 @@ struct _channelitem;
typedef struct _channelitem {
_PyCrossInterpreterData *data;
PyThread_type_lock recv_mutex;
struct _channelitem *next;
} _channelitem;
@ -612,10 +624,11 @@ _channelitem_free_all(_channelitem *item)
}
static _PyCrossInterpreterData *
_channelitem_popped(_channelitem *item)
_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
{
_PyCrossInterpreterData *data = item->data;
item->data = NULL;
*recv_mutex = item->recv_mutex;
_channelitem_free(item);
return data;
}
@ -657,13 +670,15 @@ _channelqueue_free(_channelqueue *queue)
}
static int
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
PyThread_type_lock recv_mutex)
{
_channelitem *item = _channelitem_new();
if (item == NULL) {
return -1;
}
item->data = data;
item->recv_mutex = recv_mutex;
queue->count += 1;
if (queue->first == NULL) {
@ -677,7 +692,7 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
}
static _PyCrossInterpreterData *
_channelqueue_get(_channelqueue *queue)
_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
{
_channelitem *item = queue->first;
if (item == NULL) {
@ -689,7 +704,7 @@ _channelqueue_get(_channelqueue *queue)
}
queue->count -= 1;
return _channelitem_popped(item);
return _channelitem_popped(item, recv_mutex);
}
static void
@ -1006,7 +1021,7 @@ _channel_free(_PyChannelState *chan)
static int
_channel_add(_PyChannelState *chan, int64_t interp,
_PyCrossInterpreterData *data)
_PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@ -1020,7 +1035,7 @@ _channel_add(_PyChannelState *chan, int64_t interp,
goto done;
}
if (_channelqueue_put(chan->queue, data) != 0) {
if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
goto done;
}
@ -1046,12 +1061,17 @@ _channel_next(_PyChannelState *chan, int64_t interp,
goto done;
}
_PyCrossInterpreterData *data = _channelqueue_get(chan->queue);
PyThread_type_lock recv_mutex = NULL;
_PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
chan->open = 0;
}
*res = data;
if (recv_mutex != NULL) {
PyThread_release_lock(recv_mutex);
}
done:
PyThread_release_lock(chan->mutex);
if (chan->queue->count == 0) {
@ -1571,7 +1591,8 @@ _channel_destroy(_channels *channels, int64_t id)
}
static int
_channel_send(_channels *channels, int64_t id, PyObject *obj)
_channel_send(_channels *channels, int64_t id, PyObject *obj,
PyThread_type_lock recv_mutex)
{
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
@ -1606,7 +1627,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
}
// Add the data to the channel.
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
recv_mutex);
PyThread_release_lock(mutex);
if (res != 0) {
// We may chain an exception here:
@ -2489,42 +2511,70 @@ receive end.");
static PyObject *
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", "obj", NULL};
// XXX Add a timeout arg.
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
channel_id_converter, &cid_data, &obj)) {
int blocking = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
channel_id_converter, &cid_data, &obj,
&blocking)) {
return NULL;
}
cid = cid_data.cid;
int err = _channel_send(&_globals.channels, cid, obj);
if (handle_channel_error(err, self, cid)) {
return NULL;
if (blocking) {
PyThread_type_lock mutex = PyThread_allocate_lock();
if (mutex == NULL) {
PyErr_NoMemory();
return NULL;
}
PyThread_acquire_lock(mutex, WAIT_LOCK);
/* Queue up the object. */
int err = _channel_send(&_globals.channels, cid, obj, mutex);
if (handle_channel_error(err, self, cid)) {
PyThread_release_lock(mutex);
return NULL;
}
/* Wait until the object is received. */
wait_for_lock(mutex);
}
else {
/* Queue up the object. */
int err = _channel_send(&_globals.channels, cid, obj, NULL);
if (handle_channel_error(err, self, cid)) {
return NULL;
}
}
Py_RETURN_NONE;
}
PyDoc_STRVAR(channel_send_doc,
"channel_send(cid, obj)\n\
"channel_send(cid, obj, blocking=True)\n\
\n\
Add the object's data to the channel's queue.");
Add the object's data to the channel's queue.\n\
By default this waits for the object to be received.");
static PyObject *
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", "obj", NULL};
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int blocking = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O&O:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj)) {
"O&O|$p:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj,
&blocking)) {
return NULL;
}
cid = cid_data.cid;
@ -2534,18 +2584,43 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
int err = _channel_send(&_globals.channels, cid, tempobj);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
return NULL;
if (blocking) {
PyThread_type_lock mutex = PyThread_allocate_lock();
if (mutex == NULL) {
Py_DECREF(tempobj);
PyErr_NoMemory();
return NULL;
}
PyThread_acquire_lock(mutex, WAIT_LOCK);
/* Queue up the buffer. */
int err = _channel_send(&_globals.channels, cid, tempobj, mutex);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
PyThread_acquire_lock(mutex, WAIT_LOCK);
return NULL;
}
/* Wait until the buffer is received. */
wait_for_lock(mutex);
}
else {
/* Queue up the buffer. */
int err = _channel_send(&_globals.channels, cid, tempobj, NULL);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
return NULL;
}
}
Py_RETURN_NONE;
}
PyDoc_STRVAR(channel_send_buffer_doc,
"channel_send_buffer(cid, obj)\n\
"channel_send_buffer(cid, obj, blocking=True)\n\
\n\
Add the object's buffer to the channel's queue.");
Add the object's buffer to the channel's queue.\n\
By default this waits for the object to be received.");
static PyObject *
channel_recv(PyObject *self, PyObject *args, PyObject *kwds)