mirror of https://github.com/python/cpython
gh-84570: Implement Waiting in SendChannel.send() (gh-110565)
We had been faking it (poorly). We will add timeouts separately.
This commit is contained in:
parent
46462ff929
commit
757cc35b6b
|
@ -208,11 +208,7 @@ class SendChannel(_ChannelEnd):
|
|||
|
||||
This blocks until the object is received.
|
||||
"""
|
||||
_channels.send(self._id, obj)
|
||||
# XXX We are missing a low-level channel_send_wait().
|
||||
# See bpo-32604 and gh-19829.
|
||||
# Until that shows up we fake it:
|
||||
time.sleep(2)
|
||||
_channels.send(self._id, obj, blocking=True)
|
||||
|
||||
def send_nowait(self, obj):
|
||||
"""Send the object to the channel's receiving end.
|
||||
|
@ -223,14 +219,14 @@ class SendChannel(_ChannelEnd):
|
|||
# XXX Note that at the moment channel_send() only ever returns
|
||||
# None. This should be fixed when channel_send_wait() is added.
|
||||
# See bpo-32604 and gh-19829.
|
||||
return _channels.send(self._id, obj)
|
||||
return _channels.send(self._id, obj, blocking=False)
|
||||
|
||||
def send_buffer(self, obj):
|
||||
"""Send the object's buffer to the channel's receiving end.
|
||||
|
||||
This blocks until the object is received.
|
||||
"""
|
||||
_channels.send_buffer(self._id, obj)
|
||||
_channels.send_buffer(self._id, obj, blocking=True)
|
||||
|
||||
def send_buffer_nowait(self, obj):
|
||||
"""Send the object's buffer to the channel's receiving end.
|
||||
|
@ -238,7 +234,7 @@ class SendChannel(_ChannelEnd):
|
|||
If the object is immediately received then return True
|
||||
(else False). Otherwise this is the same as send().
|
||||
"""
|
||||
return _channels.send_buffer(self._id, obj)
|
||||
return _channels.send_buffer(self._id, obj, blocking=False)
|
||||
|
||||
def close(self):
|
||||
_channels.close(self._id, send=True)
|
||||
|
|
|
@ -21,6 +21,13 @@ channels = import_helper.import_module('_xxinterpchannels')
|
|||
##################################
|
||||
# helpers
|
||||
|
||||
def recv_wait(cid):
|
||||
while True:
|
||||
try:
|
||||
return channels.recv(cid)
|
||||
except channels.ChannelEmptyError:
|
||||
time.sleep(0.1)
|
||||
|
||||
#@contextmanager
|
||||
#def run_threaded(id, source, **shared):
|
||||
# def run():
|
||||
|
@ -189,7 +196,7 @@ def run_action(cid, action, end, state, *, hideclosed=True):
|
|||
def _run_action(cid, action, end, state):
|
||||
if action == 'use':
|
||||
if end == 'send':
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
return state.incr()
|
||||
elif end == 'recv':
|
||||
if not state.pending:
|
||||
|
@ -332,7 +339,7 @@ class ChannelIDTests(TestBase):
|
|||
chan = channels.create()
|
||||
|
||||
obj = channels.create()
|
||||
channels.send(chan, obj)
|
||||
channels.send(chan, obj, blocking=False)
|
||||
got = channels.recv(chan)
|
||||
|
||||
self.assertEqual(got, obj)
|
||||
|
@ -390,7 +397,7 @@ class ChannelTests(TestBase):
|
|||
"""Test basic listing channel interpreters."""
|
||||
interp0 = interpreters.get_main()
|
||||
cid = channels.create()
|
||||
channels.send(cid, "send")
|
||||
channels.send(cid, "send", blocking=False)
|
||||
# Test for a channel that has one end associated to an interpreter.
|
||||
send_interps = channels.list_interpreters(cid, send=True)
|
||||
recv_interps = channels.list_interpreters(cid, send=False)
|
||||
|
@ -416,10 +423,10 @@ class ChannelTests(TestBase):
|
|||
interp3 = interpreters.create()
|
||||
cid = channels.create()
|
||||
|
||||
channels.send(cid, "send")
|
||||
channels.send(cid, "send", blocking=False)
|
||||
_run_output(interp1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid}, "send")
|
||||
_channels.send({cid}, "send", blocking=False)
|
||||
"""))
|
||||
_run_output(interp2, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
|
@ -439,7 +446,7 @@ class ChannelTests(TestBase):
|
|||
interp0 = interpreters.get_main()
|
||||
interp1 = interpreters.create()
|
||||
cid = channels.create()
|
||||
channels.send(cid, "send")
|
||||
channels.send(cid, "send", blocking=False)
|
||||
_run_output(interp1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
obj = _channels.recv({cid})
|
||||
|
@ -465,12 +472,12 @@ class ChannelTests(TestBase):
|
|||
interp1 = interpreters.create()
|
||||
interp2 = interpreters.create()
|
||||
cid = channels.create()
|
||||
channels.send(cid, "data")
|
||||
channels.send(cid, "data", blocking=False)
|
||||
_run_output(interp1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
obj = _channels.recv({cid})
|
||||
"""))
|
||||
channels.send(cid, "data")
|
||||
channels.send(cid, "data", blocking=False)
|
||||
_run_output(interp2, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
obj = _channels.recv({cid})
|
||||
|
@ -506,7 +513,7 @@ class ChannelTests(TestBase):
|
|||
interp1 = interpreters.create()
|
||||
cid = channels.create()
|
||||
# Put something in the channel so that it's not empty.
|
||||
channels.send(cid, "send")
|
||||
channels.send(cid, "send", blocking=False)
|
||||
|
||||
# Check initial state.
|
||||
send_interps = channels.list_interpreters(cid, send=True)
|
||||
|
@ -528,7 +535,7 @@ class ChannelTests(TestBase):
|
|||
interp1 = interpreters.create()
|
||||
cid = channels.create()
|
||||
# Put something in the channel so that it's not empty.
|
||||
channels.send(cid, "send")
|
||||
channels.send(cid, "send", blocking=False)
|
||||
|
||||
# Check initial state.
|
||||
send_interps = channels.list_interpreters(cid, send=True)
|
||||
|
@ -562,7 +569,7 @@ class ChannelTests(TestBase):
|
|||
def test_send_recv_main(self):
|
||||
cid = channels.create()
|
||||
orig = b'spam'
|
||||
channels.send(cid, orig)
|
||||
channels.send(cid, orig, blocking=False)
|
||||
obj = channels.recv(cid)
|
||||
|
||||
self.assertEqual(obj, orig)
|
||||
|
@ -574,7 +581,7 @@ class ChannelTests(TestBase):
|
|||
import _xxinterpchannels as _channels
|
||||
cid = _channels.create()
|
||||
orig = b'spam'
|
||||
_channels.send(cid, orig)
|
||||
_channels.send(cid, orig, blocking=False)
|
||||
obj = _channels.recv(cid)
|
||||
assert obj is not orig
|
||||
assert obj == orig
|
||||
|
@ -585,7 +592,7 @@ class ChannelTests(TestBase):
|
|||
id1 = interpreters.create()
|
||||
out = _run_output(id1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid}, b'spam')
|
||||
_channels.send({cid}, b'spam', blocking=False)
|
||||
"""))
|
||||
obj = channels.recv(cid)
|
||||
|
||||
|
@ -595,19 +602,14 @@ class ChannelTests(TestBase):
|
|||
cid = channels.create()
|
||||
|
||||
def f():
|
||||
while True:
|
||||
try:
|
||||
obj = channels.recv(cid)
|
||||
break
|
||||
except channels.ChannelEmptyError:
|
||||
time.sleep(0.1)
|
||||
obj = recv_wait(cid)
|
||||
channels.send(cid, obj)
|
||||
t = threading.Thread(target=f)
|
||||
t.start()
|
||||
|
||||
channels.send(cid, b'spam')
|
||||
obj = recv_wait(cid)
|
||||
t.join()
|
||||
obj = channels.recv(cid)
|
||||
|
||||
self.assertEqual(obj, b'spam')
|
||||
|
||||
|
@ -634,8 +636,8 @@ class ChannelTests(TestBase):
|
|||
t.start()
|
||||
|
||||
channels.send(cid, b'spam')
|
||||
obj = recv_wait(cid)
|
||||
t.join()
|
||||
obj = channels.recv(cid)
|
||||
|
||||
self.assertEqual(obj, b'eggs')
|
||||
|
||||
|
@ -656,10 +658,10 @@ class ChannelTests(TestBase):
|
|||
default = object()
|
||||
cid = channels.create()
|
||||
obj1 = channels.recv(cid, default)
|
||||
channels.send(cid, None)
|
||||
channels.send(cid, 1)
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'eggs')
|
||||
channels.send(cid, None, blocking=False)
|
||||
channels.send(cid, 1, blocking=False)
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'eggs', blocking=False)
|
||||
obj2 = channels.recv(cid, default)
|
||||
obj3 = channels.recv(cid, default)
|
||||
obj4 = channels.recv(cid)
|
||||
|
@ -679,7 +681,7 @@ class ChannelTests(TestBase):
|
|||
interp = interpreters.create()
|
||||
interpreters.run_string(interp, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid1}, b'spam')
|
||||
_channels.send({cid1}, b'spam', blocking=False)
|
||||
"""))
|
||||
interpreters.destroy(interp)
|
||||
|
||||
|
@ -692,9 +694,9 @@ class ChannelTests(TestBase):
|
|||
interp = interpreters.create()
|
||||
interpreters.run_string(interp, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid2}, b'spam')
|
||||
_channels.send({cid2}, b'spam', blocking=False)
|
||||
"""))
|
||||
channels.send(cid2, b'eggs')
|
||||
channels.send(cid2, b'eggs', blocking=False)
|
||||
interpreters.destroy(interp)
|
||||
|
||||
channels.recv(cid2)
|
||||
|
@ -706,7 +708,7 @@ class ChannelTests(TestBase):
|
|||
def test_send_buffer(self):
|
||||
buf = bytearray(b'spamspamspam')
|
||||
cid = channels.create()
|
||||
channels.send_buffer(cid, buf)
|
||||
channels.send_buffer(cid, buf, blocking=False)
|
||||
obj = channels.recv(cid)
|
||||
|
||||
self.assertIsNot(obj, buf)
|
||||
|
@ -728,7 +730,7 @@ class ChannelTests(TestBase):
|
|||
]
|
||||
for obj in objects:
|
||||
with self.subTest(obj):
|
||||
channels.send(cid, obj)
|
||||
channels.send(cid, obj, blocking=False)
|
||||
got = channels.recv(cid)
|
||||
|
||||
self.assertEqual(got, obj)
|
||||
|
@ -744,7 +746,7 @@ class ChannelTests(TestBase):
|
|||
out = _run_output(interp, dedent("""
|
||||
import _xxinterpchannels as _channels
|
||||
print(cid.end)
|
||||
_channels.send(cid, b'spam')
|
||||
_channels.send(cid, b'spam', blocking=False)
|
||||
"""),
|
||||
dict(cid=cid.send))
|
||||
obj = channels.recv(cid)
|
||||
|
@ -764,7 +766,7 @@ class ChannelTests(TestBase):
|
|||
out = _run_output(interp, dedent("""
|
||||
import _xxinterpchannels as _channels
|
||||
print(chan.id.end)
|
||||
_channels.send(chan.id, b'spam')
|
||||
_channels.send(chan.id, b'spam', blocking=False)
|
||||
"""),
|
||||
dict(chan=cid.send))
|
||||
obj = channels.recv(cid)
|
||||
|
@ -776,7 +778,7 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_single_user(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.close(cid)
|
||||
|
||||
|
@ -791,7 +793,7 @@ class ChannelTests(TestBase):
|
|||
id2 = interpreters.create()
|
||||
interpreters.run_string(id1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid}, b'spam')
|
||||
_channels.send({cid}, b'spam', blocking=False)
|
||||
"""))
|
||||
interpreters.run_string(id2, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
|
@ -811,7 +813,7 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_multiple_times(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.close(cid)
|
||||
|
||||
|
@ -828,7 +830,7 @@ class ChannelTests(TestBase):
|
|||
for send, recv in tests:
|
||||
with self.subTest((send, recv)):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.close(cid, send=send, recv=recv)
|
||||
|
||||
|
@ -839,31 +841,31 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_defaults_with_unused_items(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
|
||||
with self.assertRaises(channels.ChannelNotEmptyError):
|
||||
channels.close(cid)
|
||||
channels.recv(cid)
|
||||
channels.send(cid, b'eggs')
|
||||
channels.send(cid, b'eggs', blocking=False)
|
||||
|
||||
def test_close_recv_with_unused_items_unforced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
|
||||
with self.assertRaises(channels.ChannelNotEmptyError):
|
||||
channels.close(cid, recv=True)
|
||||
channels.recv(cid)
|
||||
channels.send(cid, b'eggs')
|
||||
channels.send(cid, b'eggs', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.recv(cid)
|
||||
channels.close(cid, recv=True)
|
||||
|
||||
def test_close_send_with_unused_items_unforced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
channels.close(cid, send=True)
|
||||
|
||||
with self.assertRaises(channels.ChannelClosedError):
|
||||
|
@ -875,21 +877,21 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_both_with_unused_items_unforced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
|
||||
with self.assertRaises(channels.ChannelNotEmptyError):
|
||||
channels.close(cid, recv=True, send=True)
|
||||
channels.recv(cid)
|
||||
channels.send(cid, b'eggs')
|
||||
channels.send(cid, b'eggs', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.recv(cid)
|
||||
channels.close(cid, recv=True)
|
||||
|
||||
def test_close_recv_with_unused_items_forced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
channels.close(cid, recv=True, force=True)
|
||||
|
||||
with self.assertRaises(channels.ChannelClosedError):
|
||||
|
@ -899,8 +901,8 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_send_with_unused_items_forced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
channels.close(cid, send=True, force=True)
|
||||
|
||||
with self.assertRaises(channels.ChannelClosedError):
|
||||
|
@ -910,8 +912,8 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_both_with_unused_items_forced(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
channels.close(cid, send=True, recv=True, force=True)
|
||||
|
||||
with self.assertRaises(channels.ChannelClosedError):
|
||||
|
@ -930,7 +932,7 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_by_unassociated_interp(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
interp = interpreters.create()
|
||||
interpreters.run_string(interp, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
|
@ -943,9 +945,9 @@ class ChannelTests(TestBase):
|
|||
|
||||
def test_close_used_multiple_times_by_single_user(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.close(cid, force=True)
|
||||
|
||||
|
@ -1017,7 +1019,7 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_single_user(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.release(cid, send=True, recv=True)
|
||||
|
||||
|
@ -1032,7 +1034,7 @@ class ChannelReleaseTests(TestBase):
|
|||
id2 = interpreters.create()
|
||||
interpreters.run_string(id1, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
_channels.send({cid}, b'spam')
|
||||
_channels.send({cid}, b'spam', blocking=False)
|
||||
"""))
|
||||
out = _run_output(id2, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
|
@ -1048,7 +1050,7 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_no_kwargs(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.release(cid)
|
||||
|
||||
|
@ -1059,7 +1061,7 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_multiple_times(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.release(cid, send=True, recv=True)
|
||||
|
||||
|
@ -1068,8 +1070,8 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_with_unused_items(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'ham')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'ham', blocking=False)
|
||||
channels.release(cid, send=True, recv=True)
|
||||
|
||||
with self.assertRaises(channels.ChannelClosedError):
|
||||
|
@ -1086,7 +1088,7 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_by_unassociated_interp(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
interp = interpreters.create()
|
||||
interpreters.run_string(interp, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
|
@ -1105,7 +1107,7 @@ class ChannelReleaseTests(TestBase):
|
|||
interp = interpreters.create()
|
||||
interpreters.run_string(interp, dedent(f"""
|
||||
import _xxinterpchannels as _channels
|
||||
obj = _channels.send({cid}, b'spam')
|
||||
obj = _channels.send({cid}, b'spam', blocking=False)
|
||||
_channels.release({cid})
|
||||
"""))
|
||||
|
||||
|
@ -1115,9 +1117,9 @@ class ChannelReleaseTests(TestBase):
|
|||
def test_partially(self):
|
||||
# XXX Is partial close too weird/confusing?
|
||||
cid = channels.create()
|
||||
channels.send(cid, None)
|
||||
channels.send(cid, None, blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.release(cid, send=True)
|
||||
obj = channels.recv(cid)
|
||||
|
||||
|
@ -1125,9 +1127,9 @@ class ChannelReleaseTests(TestBase):
|
|||
|
||||
def test_used_multiple_times_by_single_user(self):
|
||||
cid = channels.create()
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
channels.recv(cid)
|
||||
channels.release(cid, send=True, recv=True)
|
||||
|
||||
|
@ -1212,7 +1214,7 @@ class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
|
|||
cid = _xxsubchannels.create()
|
||||
# We purposefully send back an int to avoid tying the
|
||||
# channel to the other interpreter.
|
||||
_xxsubchannels.send({ch}, int(cid))
|
||||
_xxsubchannels.send({ch}, int(cid), blocking=False)
|
||||
del _xxsubinterpreters
|
||||
""")
|
||||
self._cid = channels.recv(ch)
|
||||
|
@ -1442,8 +1444,8 @@ class ExhaustiveChannelTests(TestBase):
|
|||
{repr(fix.state)},
|
||||
hideclosed={hideclosed},
|
||||
)
|
||||
channels.send({_cid}, result.pending.to_bytes(1, 'little'))
|
||||
channels.send({_cid}, b'X' if result.closed else b'')
|
||||
channels.send({_cid}, result.pending.to_bytes(1, 'little'), blocking=False)
|
||||
channels.send({_cid}, b'X' if result.closed else b'', blocking=False)
|
||||
""")
|
||||
result = ChannelState(
|
||||
pending=int.from_bytes(channels.recv(_cid), 'little'),
|
||||
|
@ -1490,7 +1492,7 @@ class ExhaustiveChannelTests(TestBase):
|
|||
""")
|
||||
run_interp(interp.id, """
|
||||
with helpers.expect_channel_closed():
|
||||
channels.send(cid, b'spam')
|
||||
channels.send(cid, b'spam', blocking=False)
|
||||
""")
|
||||
run_interp(interp.id, """
|
||||
with helpers.expect_channel_closed():
|
||||
|
|
|
@ -964,8 +964,8 @@ class TestSendRecv(TestBase):
|
|||
|
||||
orig = b'spam'
|
||||
s.send(orig)
|
||||
t.join()
|
||||
obj = r.recv()
|
||||
t.join()
|
||||
|
||||
self.assertEqual(obj, orig)
|
||||
self.assertIsNot(obj, orig)
|
||||
|
|
|
@ -234,6 +234,17 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
|
|||
return cls;
|
||||
}
|
||||
|
||||
static void
|
||||
wait_for_lock(PyThread_type_lock mutex)
|
||||
{
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
// XXX Handle eintr, etc.
|
||||
PyThread_acquire_lock(mutex, WAIT_LOCK);
|
||||
Py_END_ALLOW_THREADS
|
||||
|
||||
PyThread_release_lock(mutex);
|
||||
}
|
||||
|
||||
|
||||
/* Cross-interpreter Buffer Views *******************************************/
|
||||
|
||||
|
@ -567,6 +578,7 @@ struct _channelitem;
|
|||
|
||||
typedef struct _channelitem {
|
||||
_PyCrossInterpreterData *data;
|
||||
PyThread_type_lock recv_mutex;
|
||||
struct _channelitem *next;
|
||||
} _channelitem;
|
||||
|
||||
|
@ -612,10 +624,11 @@ _channelitem_free_all(_channelitem *item)
|
|||
}
|
||||
|
||||
static _PyCrossInterpreterData *
|
||||
_channelitem_popped(_channelitem *item)
|
||||
_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
|
||||
{
|
||||
_PyCrossInterpreterData *data = item->data;
|
||||
item->data = NULL;
|
||||
*recv_mutex = item->recv_mutex;
|
||||
_channelitem_free(item);
|
||||
return data;
|
||||
}
|
||||
|
@ -657,13 +670,15 @@ _channelqueue_free(_channelqueue *queue)
|
|||
}
|
||||
|
||||
static int
|
||||
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
|
||||
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
|
||||
PyThread_type_lock recv_mutex)
|
||||
{
|
||||
_channelitem *item = _channelitem_new();
|
||||
if (item == NULL) {
|
||||
return -1;
|
||||
}
|
||||
item->data = data;
|
||||
item->recv_mutex = recv_mutex;
|
||||
|
||||
queue->count += 1;
|
||||
if (queue->first == NULL) {
|
||||
|
@ -677,7 +692,7 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
|
|||
}
|
||||
|
||||
static _PyCrossInterpreterData *
|
||||
_channelqueue_get(_channelqueue *queue)
|
||||
_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
|
||||
{
|
||||
_channelitem *item = queue->first;
|
||||
if (item == NULL) {
|
||||
|
@ -689,7 +704,7 @@ _channelqueue_get(_channelqueue *queue)
|
|||
}
|
||||
queue->count -= 1;
|
||||
|
||||
return _channelitem_popped(item);
|
||||
return _channelitem_popped(item, recv_mutex);
|
||||
}
|
||||
|
||||
static void
|
||||
|
@ -1006,7 +1021,7 @@ _channel_free(_PyChannelState *chan)
|
|||
|
||||
static int
|
||||
_channel_add(_PyChannelState *chan, int64_t interp,
|
||||
_PyCrossInterpreterData *data)
|
||||
_PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
|
||||
{
|
||||
int res = -1;
|
||||
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
|
||||
|
@ -1020,7 +1035,7 @@ _channel_add(_PyChannelState *chan, int64_t interp,
|
|||
goto done;
|
||||
}
|
||||
|
||||
if (_channelqueue_put(chan->queue, data) != 0) {
|
||||
if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
|
||||
goto done;
|
||||
}
|
||||
|
||||
|
@ -1046,12 +1061,17 @@ _channel_next(_PyChannelState *chan, int64_t interp,
|
|||
goto done;
|
||||
}
|
||||
|
||||
_PyCrossInterpreterData *data = _channelqueue_get(chan->queue);
|
||||
PyThread_type_lock recv_mutex = NULL;
|
||||
_PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
|
||||
if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
|
||||
chan->open = 0;
|
||||
}
|
||||
*res = data;
|
||||
|
||||
if (recv_mutex != NULL) {
|
||||
PyThread_release_lock(recv_mutex);
|
||||
}
|
||||
|
||||
done:
|
||||
PyThread_release_lock(chan->mutex);
|
||||
if (chan->queue->count == 0) {
|
||||
|
@ -1571,7 +1591,8 @@ _channel_destroy(_channels *channels, int64_t id)
|
|||
}
|
||||
|
||||
static int
|
||||
_channel_send(_channels *channels, int64_t id, PyObject *obj)
|
||||
_channel_send(_channels *channels, int64_t id, PyObject *obj,
|
||||
PyThread_type_lock recv_mutex)
|
||||
{
|
||||
PyInterpreterState *interp = _get_current_interp();
|
||||
if (interp == NULL) {
|
||||
|
@ -1606,7 +1627,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
|
|||
}
|
||||
|
||||
// Add the data to the channel.
|
||||
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
|
||||
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
|
||||
recv_mutex);
|
||||
PyThread_release_lock(mutex);
|
||||
if (res != 0) {
|
||||
// We may chain an exception here:
|
||||
|
@ -2489,42 +2511,70 @@ receive end.");
|
|||
static PyObject *
|
||||
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
static char *kwlist[] = {"cid", "obj", NULL};
|
||||
// XXX Add a timeout arg.
|
||||
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
|
||||
int64_t cid;
|
||||
struct channel_id_converter_data cid_data = {
|
||||
.module = self,
|
||||
};
|
||||
PyObject *obj;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
|
||||
channel_id_converter, &cid_data, &obj)) {
|
||||
int blocking = 1;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
|
||||
channel_id_converter, &cid_data, &obj,
|
||||
&blocking)) {
|
||||
return NULL;
|
||||
}
|
||||
cid = cid_data.cid;
|
||||
|
||||
int err = _channel_send(&_globals.channels, cid, obj);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
return NULL;
|
||||
if (blocking) {
|
||||
PyThread_type_lock mutex = PyThread_allocate_lock();
|
||||
if (mutex == NULL) {
|
||||
PyErr_NoMemory();
|
||||
return NULL;
|
||||
}
|
||||
PyThread_acquire_lock(mutex, WAIT_LOCK);
|
||||
|
||||
/* Queue up the object. */
|
||||
int err = _channel_send(&_globals.channels, cid, obj, mutex);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
PyThread_release_lock(mutex);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Wait until the object is received. */
|
||||
wait_for_lock(mutex);
|
||||
}
|
||||
else {
|
||||
/* Queue up the object. */
|
||||
int err = _channel_send(&_globals.channels, cid, obj, NULL);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(channel_send_doc,
|
||||
"channel_send(cid, obj)\n\
|
||||
"channel_send(cid, obj, blocking=True)\n\
|
||||
\n\
|
||||
Add the object's data to the channel's queue.");
|
||||
Add the object's data to the channel's queue.\n\
|
||||
By default this waits for the object to be received.");
|
||||
|
||||
static PyObject *
|
||||
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
static char *kwlist[] = {"cid", "obj", NULL};
|
||||
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
|
||||
int64_t cid;
|
||||
struct channel_id_converter_data cid_data = {
|
||||
.module = self,
|
||||
};
|
||||
PyObject *obj;
|
||||
int blocking = 1;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds,
|
||||
"O&O:channel_send_buffer", kwlist,
|
||||
channel_id_converter, &cid_data, &obj)) {
|
||||
"O&O|$p:channel_send_buffer", kwlist,
|
||||
channel_id_converter, &cid_data, &obj,
|
||||
&blocking)) {
|
||||
return NULL;
|
||||
}
|
||||
cid = cid_data.cid;
|
||||
|
@ -2534,18 +2584,43 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
int err = _channel_send(&_globals.channels, cid, tempobj);
|
||||
Py_DECREF(tempobj);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
return NULL;
|
||||
if (blocking) {
|
||||
PyThread_type_lock mutex = PyThread_allocate_lock();
|
||||
if (mutex == NULL) {
|
||||
Py_DECREF(tempobj);
|
||||
PyErr_NoMemory();
|
||||
return NULL;
|
||||
}
|
||||
PyThread_acquire_lock(mutex, WAIT_LOCK);
|
||||
|
||||
/* Queue up the buffer. */
|
||||
int err = _channel_send(&_globals.channels, cid, tempobj, mutex);
|
||||
Py_DECREF(tempobj);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
PyThread_acquire_lock(mutex, WAIT_LOCK);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Wait until the buffer is received. */
|
||||
wait_for_lock(mutex);
|
||||
}
|
||||
else {
|
||||
/* Queue up the buffer. */
|
||||
int err = _channel_send(&_globals.channels, cid, tempobj, NULL);
|
||||
Py_DECREF(tempobj);
|
||||
if (handle_channel_error(err, self, cid)) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(channel_send_buffer_doc,
|
||||
"channel_send_buffer(cid, obj)\n\
|
||||
"channel_send_buffer(cid, obj, blocking=True)\n\
|
||||
\n\
|
||||
Add the object's buffer to the channel's queue.");
|
||||
Add the object's buffer to the channel's queue.\n\
|
||||
By default this waits for the object to be received.");
|
||||
|
||||
static PyObject *
|
||||
channel_recv(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
|
|
Loading…
Reference in New Issue