# cpython/Lib/test/test__xxsubinterpreters.py
#
# 1230 lines
# 39 KiB
# Python
# Raw Normal View History

import contextlib
import os
import pickle
from textwrap import dedent, indent
import threading
import time
import unittest
from test import support
from test.support import script_helper
interpreters = support.import_module('_xxsubinterpreters')
def _captured_script(script):
r, w = os.pipe()
indented = script.replace('\n', '\n ')
wrapped = dedent(f"""
import contextlib
with open({w}, 'w') as chan:
with contextlib.redirect_stdout(chan):
{indented}
""")
return wrapped, open(r)
def _run_output(interp, request, shared=None):
    """Run *request* in *interp* and return the text it wrote to stdout.

    *shared* is passed through unchanged as the third argument of
    interpreters.run_string().
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        captured = reader.read()
    return captured
@contextlib.contextmanager
def _running(interp):
    """Context manager that keeps *interp* busy for the duration of the block.

    A background thread runs a script in *interp* that blocks reading from a
    pipe.  When the ``with`` block is left, 'done' is written to the pipe and
    the write end is closed, which lets the script finish; the thread is then
    joined.
    """
    r, w = os.pipe()

    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Signal and join even when the with-body raised (e.g. a failed
        # assertion); otherwise the worker thread would stay blocked on the
        # pipe forever and never be joined.
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()
class IsShareableTests(unittest.TestCase):
    """Exercise ``interpreters.is_shareable()`` on assorted objects."""

    def test_default_shareables(self):
        # Objects for which is_shareable() is expected to return true.
        accepted = (
            # singletons
            None,
            # builtin objects
            b'spam',
        )
        for candidate in accepted:
            with self.subTest(candidate):
                verdict = interpreters.is_shareable(candidate)
                self.assertTrue(verdict)

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        # Objects for which is_shareable() is expected to return false.
        rejected = (
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        )
        for candidate in rejected:
            with self.subTest(candidate):
                verdict = interpreters.is_shareable(candidate)
                self.assertFalse(verdict)
class TestBase(unittest.TestCase):
    """Base class that destroys leftover interpreters and channels."""

    def tearDown(self):
        # Destroy every interpreter except the one with ID 0 (main).
        for interp_id in interpreters.list_all():
            if interp_id != 0:
                # A RuntimeError here is treated as "already destroyed".
                with contextlib.suppress(RuntimeError):
                    interpreters.destroy(interp_id)

        # Then destroy every channel that is still listed.
        for channel_id in interpreters.channel_list_all():
            # ChannelNotFoundError is treated as "already destroyed".
            with contextlib.suppress(interpreters.ChannelNotFoundError):
                interpreters.channel_destroy(channel_id)
class ListAllTests(TestBase):
    """Check what ``interpreters.list_all()`` reports."""

    def test_initial(self):
        # With nothing created, only the main interpreter is listed.
        expected = [interpreters.get_main()]
        self.assertEqual(interpreters.list_all(), expected)

    def test_after_creating(self):
        main = interpreters.get_main()
        created = [interpreters.create(), interpreters.create()]
        listed = interpreters.list_all()
        self.assertEqual(listed, [main, *created])

    def test_after_destroying(self):
        main = interpreters.get_main()
        doomed = interpreters.create()
        survivor = interpreters.create()
        interpreters.destroy(doomed)
        listed = interpreters.list_all()
        self.assertEqual(listed, [main, survivor])
class GetCurrentTests(TestBase):
    """Check ``interpreters.get_current()`` from main and from a child."""

    def test_main(self):
        expected = interpreters.get_main()
        self.assertEqual(interpreters.get_current(), expected)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        # Have the subinterpreter print the ID it sees as "current".
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_current()))
            """)
        reported = int(_run_output(interp, script).strip())
        # Exactly two interpreters exist now; the second is the new one.
        _, expected = interpreters.list_all()
        self.assertEqual(reported, expected)
        self.assertNotEqual(reported, main)
class GetMainTests(TestBase):
    """Check ``interpreters.get_main()`` from main and from a child."""

    def test_from_main(self):
        # Only one interpreter exists, so it must be the main one.
        expected, = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), expected)

    def test_from_subinterpreter(self):
        expected, = interpreters.list_all()
        interp = interpreters.create()
        # Ask the subinterpreter which ID it considers to be main.
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_main()))
            """)
        reported = int(_run_output(interp, script).strip())
        self.assertEqual(reported, expected)
class IsRunningTests(TestBase):
    """Check ``interpreters.is_running()``."""

    def test_main(self):
        main_id = interpreters.get_main()
        self.assertTrue(interpreters.is_running(main_id))

    def test_subinterpreter(self):
        interp = interpreters.create()
        # Not running before, running inside the block, not running after.
        self.assertFalse(interpreters.is_running(interp))
        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        # The subinterpreter checks its own ID while executing the script.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({int(interp)}):
                print(True)
            else:
                print(False)
            """)
        printed = _run_output(interp, script)
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        with self.assertRaises(RuntimeError):
            interpreters.is_running(interp)

    def test_does_not_exist(self):
        missing = 1_000_000
        with self.assertRaises(RuntimeError):
            interpreters.is_running(missing)

    def test_bad_id(self):
        invalid = -1
        with self.assertRaises(RuntimeError):
            interpreters.is_running(invalid)
class InterpreterIDTests(TestBase):
    """Check construction, coercion, repr and comparison of InterpreterID."""

    def test_with_int(self):
        interp_id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(int(interp_id), 10)

    def test_coerce_id(self):
        # A numeric string, a float and an object with __int__ are accepted.
        interp_id = interpreters.InterpreterID('10', force=True)
        self.assertEqual(int(interp_id), 10)

        interp_id = interpreters.InterpreterID(10.0, force=True)
        self.assertEqual(int(interp_id), 10)

        class Int(str):
            def __init__(self, value):
                self._value = value

            def __int__(self):
                return self._value

        interp_id = interpreters.InterpreterID(Int(10), force=True)
        self.assertEqual(int(interp_id), 10)

    def test_bad_id(self):
        for bad in [-1, 'spam']:
            with self.subTest(bad):
                with self.assertRaises(ValueError):
                    interpreters.InterpreterID(bad)
        with self.assertRaises(OverflowError):
            interpreters.InterpreterID(2**64)
        with self.assertRaises(TypeError):
            interpreters.InterpreterID(object())

    def test_does_not_exist(self):
        # Derive the missing ID from a real interpreter rather than from a
        # channel: channel IDs are unrelated to interpreter IDs.
        # NOTE(review): assumes the newest interpreter has the highest ID in
        # use, so the next integer is unused -- confirm against the module.
        newest = interpreters.create()
        with self.assertRaises(RuntimeError):
            interpreters.InterpreterID(int(newest) + 1)  # unforced

    def test_repr(self):
        interp_id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(repr(interp_id), 'InterpreterID(10)')

    def test_equality(self):
        id1 = interpreters.create()
        id2 = interpreters.InterpreterID(int(id1))
        id3 = interpreters.create()

        self.assertTrue(id1 == id1)
        self.assertTrue(id1 == id2)
        self.assertTrue(id1 == int(id1))
        self.assertFalse(id1 == id3)

        self.assertFalse(id1 != id1)
        self.assertFalse(id1 != id2)
        self.assertTrue(id1 != id3)
class CreateTests(TestBase):
    """Check ``interpreters.create()`` from various execution contexts."""

    def test_in_main(self):
        created = interpreters.create()
        self.assertIn(created, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            created = interpreters.create()
            interpreters.destroy(created)
            seen.add(created)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        lock = threading.Lock()
        created = None

        def f():
            nonlocal created
            created = interpreters.create()
            # Wait until the starting thread has released the lock.
            lock.acquire()
            lock.release()

        worker = threading.Thread(target=f)
        with lock:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        # The subinterpreter creates another one and prints its ID.
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(int(id))
            """)
        id2 = int(_run_output(id1, script).strip())

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = None

        def f():
            nonlocal id2
            printed = _run_output(id1, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(int(id))
                """))
            id2 = int(printed.strip())

        worker = threading.Thread(target=f)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        batch = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp_id in batch:
            interpreters.destroy(interp_id)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        # Now destroy 2 of them.
        interpreters.destroy(id1)
        interpreters.destroy(id3)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest, id2})
class DestroyTests(TestBase):
    """Check ``interpreters.destroy()`` from various execution contexts."""

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        # Only the destroyed interpreter disappears from the listing.
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        created = {interpreters.create() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all()), before | created)
        for interp_id in created:
            interpreters.destroy(interp_id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        # Repeat the same check from a second thread.
        def f():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        worker = threading.Thread(target=f)
        worker.start()
        worker.join()

    def test_already_destroyed(self):
        interp_id = interpreters.create()
        interpreters.destroy(interp_id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(interp_id)

    def test_does_not_exist(self):
        missing = 1_000_000
        with self.assertRaises(RuntimeError):
            interpreters.destroy(missing)

    def test_bad_id(self):
        invalid = -1
        with self.assertRaises(RuntimeError):
            interpreters.destroy(invalid)

    def test_from_current(self):
        main, = interpreters.list_all()
        id = interpreters.create()
        # The interpreter tries to destroy itself; afterwards it must still
        # be listed.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({int(id)})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(id, script)
        self.assertEqual(set(interpreters.list_all()), {main, id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        # id1 destroys its sibling id2.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({int(id2)})
            """)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp_id = interpreters.create()

        def f():
            interpreters.destroy(interp_id)

        worker = threading.Thread(target=f)
        worker.start()
        worker.join()

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            # Destroying a running interpreter fails and leaves it running.
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
class RunStringTests(TestBase):
    """Exercise ``interpreters.run_string()`` on a per-test subinterpreter."""

    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        # Every test gets its own fresh subinterpreter.
        self.id = interpreters.create()
        self._fs = None

    def tearDown(self):
        fixture = self._fs
        if fixture is not None:
            fixture.close()
        super().tearDown()

    @property
    def fs(self):
        # Build the fixture lazily on first access and cache it.
        # NOTE(review): FSFixture is not defined in the visible part of this
        # file -- confirm where it is supposed to come from.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            printed = file.read()

        self.assertEqual(printed, 'it worked!')

    def test_in_thread(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            worker = threading.Thread(target=f)
            worker.start()
            worker.join()
            printed = file.read()

        self.assertEqual(printed, 'it worked!')

    def test_create_thread(self):
        # The script itself starts a thread inside the subinterpreter.
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')
            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            printed = file.read()

        self.assertEqual(printed, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The forked child writes *expected* to the temp file and exits;
            # the parent polls for the child with waitpid().
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0: # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    assert(spid == pid)
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        # Find the smallest non-negative integer that is not a listed ID.
        missing = 0
        while missing in interpreters.list_all():
            missing += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(missing, 'print("spam")')

    def test_error_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        # Expect RunFailedError whose text starts with str(exctype); when
        # *msg* is given, the whole text must be "<exctype>: <msg>".
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        text = str(caught.exception)
        if msg is not None:
            self.assertEqual(text, "{}: {}".format(exctype, msg))
        else:
            self.assertEqual(text.split(':')[0], str(exctype))

    def test_invalid_syntax(self):
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        r, w = os.pipe()

        shared = {
            'spam': b'ham',
            'eggs': b'-1',
            'cheddar': None,
        }
        # The script pickles its namespace back to us through the pipe.
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        # First run: define "spam" inside the interpreter.
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        # Second run: pass a shared "spam" as well.
        shared = {'spam': b'ham'}
        script = dedent(f"""
            ns2 = dict(vars())
            del ns2['__builtins__']
            """)
        interpreters.run_string(self.id, script, shared)

        # Third run: read everything back through a pipe.
        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        # First run defines "spam" and reports the namespace.
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        # Second run defines "eggs"; "spam" must still be there.
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        r, w = os.pipe()

        script = dedent(f"""
            spam = 42
            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        # These two entries are not compared; they only have to be present.
        ns.pop('__builtins__')
        ns.pop('__loader__')
        expected = {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
        }
        self.assertEqual(ns, expected)

    # XXX Fix this test!
    @unittest.skip('blocking forever')
    def test_still_running_at_exit(self):
        script = dedent(f"""
        from textwrap import dedent
        import threading
        import _xxsubinterpreters as _interpreters
        id = _interpreters.create()
        def f():
            _interpreters.run_string(id, dedent('''
                import time
                # Give plenty of time for the main interpreter to finish.
                time.sleep(1_000_000)
                '''))
        t = threading.Thread(target=f)
        t.start()
        """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)
class ChannelIDTests(TestBase):
    """Check construction, coercion, repr and comparison of channel IDs."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        make = interpreters._channel_id

        # send only
        cid = make(10, send=True, force=True)
        self.assertEqual(cid.end, 'send')

        cid = make(10, send=True, recv=False, force=True)
        self.assertEqual(cid.end, 'send')

        # recv only
        cid = make(10, recv=True, force=True)
        self.assertEqual(cid.end, 'recv')

        cid = make(10, recv=True, send=False, force=True)
        self.assertEqual(cid.end, 'recv')

        # both ends
        cid = make(10, send=True, recv=True, force=True)
        self.assertEqual(cid.end, 'both')

    def test_coerce_id(self):
        # A numeric string, a float and an object with __int__ are accepted.
        cid = interpreters._channel_id('10', force=True)
        self.assertEqual(int(cid), 10)

        cid = interpreters._channel_id(10.0, force=True)
        self.assertEqual(int(cid), 10)

        class Int(str):
            def __init__(self, value):
                self._value = value

            def __int__(self):
                return self._value

        cid = interpreters._channel_id(Int(10), force=True)
        self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for bad in [-1, 'spam']:
            with self.subTest(bad):
                with self.assertRaises(ValueError):
                    interpreters._channel_id(bad)
        with self.assertRaises(OverflowError):
            interpreters._channel_id(2**64)
        with self.assertRaises(TypeError):
            interpreters._channel_id(object())

    def test_bad_kwargs(self):
        with self.assertRaises(ValueError):
            interpreters._channel_id(10, send=False, recv=False)

    def test_does_not_exist(self):
        newest = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(int(newest) + 1)  # unforced

    def test_repr(self):
        make = interpreters._channel_id

        cid = make(10, force=True)
        self.assertEqual(repr(cid), 'ChannelID(10)')

        cid = make(10, send=True, force=True)
        self.assertEqual(repr(cid), 'ChannelID(10, send=True)')

        cid = make(10, recv=True, force=True)
        self.assertEqual(repr(cid), 'ChannelID(10, recv=True)')

        cid = make(10, send=True, recv=True, force=True)
        self.assertEqual(repr(cid), 'ChannelID(10)')

    def test_equality(self):
        cid1 = interpreters.channel_create()
        cid2 = interpreters._channel_id(int(cid1))
        cid3 = interpreters.channel_create()

        self.assertTrue(cid1 == cid1)
        self.assertTrue(cid1 == cid2)
        self.assertTrue(cid1 == int(cid1))
        self.assertFalse(cid1 == cid3)

        self.assertFalse(cid1 != cid1)
        self.assertFalse(cid1 != cid2)
        self.assertTrue(cid1 != cid3)
class ChannelTests(TestBase):
    """Check channel creation, send/recv, drop and close behaviour."""

    def test_sequential_ids(self):
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        # Each new channel ID is one greater than the previous one.
        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        # Channels created in two different interpreters get consecutive IDs.
        first = interpreters.create()
        printed = _run_output(first, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid1 = int(printed.strip())

        second = interpreters.create()
        printed = _run_output(second, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid2 = int(printed.strip())

        self.assertEqual(cid2, int(cid1) + 1)

    ####################

    def test_drop_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        # Both directions are closed afterwards.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_users(self):
        cid = interpreters.channel_create()
        sender = interpreters.create()
        receiver = interpreters.create()
        interpreters.run_string(sender, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        printed = _run_output(receiver, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({int(cid)})
            _interpreters.channel_drop_interpreter({int(cid)})
            print(repr(obj))
            """))
        interpreters.run_string(sender, dedent(f"""
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        self.assertEqual(printed.strip(), "b'spam'")

    def test_drop_no_kwargs(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        # A second drop of the same channel is rejected.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_drop_interpreter(cid, send=True, recv=True)

    def test_drop_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        # This interpreter never sent or received on the channel.
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_drop_interpreter({int(cid)})
            """))
        received = interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(received, b'spam')

    def test_drop_close_if_unassociated(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({int(cid)}, b'spam')
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_partially(self):
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_drop_interpreter(cid, send=True)
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_drop_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({int(cid)})
            """))
        interpreters.channel_close(cid)

        # After the close, sending from either interpreter fails.
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        # A second close of the same channel is rejected.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        # This interpreter never sent or received on the channel.
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({int(cid)})
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        received = interpreters.channel_recv(cid)

        # The received object is equal to, but not the same object as, the
        # one that was sent.
        self.assertEqual(received, orig)
        self.assertIsNot(received, orig)

    def test_send_recv_same_interpreter(self):
        interp = interpreters.create()
        # The checks are made by the asserts inside the script.
        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()
        out = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))

        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_send_recv_different_threads(self):
        cid = interpreters.channel_create()

        def f():
            # Poll until something arrives, then send it straight back.
            while True:
                try:
                    obj = interpreters.channel_recv(cid)
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            interpreters.channel_send(cid, obj)

        worker = threading.Thread(target=f)
        worker.start()

        interpreters.channel_send(cid, b'spam')
        worker.join()
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_send_recv_different_interpreters_and_threads(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        out = None

        def f():
            nonlocal out
            # The subinterpreter polls for b'spam' and answers with b'eggs'.
            out = _run_output(id1, dedent(f"""
                import time
                import _xxsubinterpreters as _interpreters
                while True:
                    try:
                        obj = _interpreters.channel_recv({int(cid)})
                        break
                    except _interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                assert(obj == b'spam')
                _interpreters.channel_send({int(cid)}, b'eggs')
                """))

        worker = threading.Thread(target=f)
        worker.start()

        interpreters.channel_send(cid, b'spam')
        worker.join()
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'eggs')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()

        # The send end of the channel is handed to the script as "cid".
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')
        self.assertEqual(printed.strip(), 'send')
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()