1121 lines
35 KiB
Python
1121 lines
35 KiB
Python
import contextlib
|
|
import os
|
|
import pickle
|
|
from textwrap import dedent, indent
|
|
import threading
|
|
import unittest
|
|
|
|
from test import support
|
|
from test.support import script_helper
|
|
|
|
interpreters = support.import_module('_xxsubinterpreters')
|
|
|
|
|
|
def _captured_script(script):
    """Wrap *script* so that everything it prints goes into a fresh pipe.

    Returns a ``(wrapped, reader)`` pair.  *wrapped* is source text that
    opens the pipe's write end by file descriptor, redirects stdout to it
    and then runs *script*.  *reader* is a text file object open on the
    pipe's read end; the caller is responsible for closing it.  The write
    end is only closed once *wrapped* has actually been executed.

    Both ends of the pipe use UTF-8 explicitly so the captured text does
    not depend on the locale's default encoding.
    """
    r, w = os.pipe()
    # Continuation lines are pushed to the column of the {indented}
    # placeholder (16 spaces) so that dedent() sees one consistent block.
    indented = script.replace('\n', '\n                ')
    wrapped = dedent(f"""
        import contextlib
        with open({w}, 'w', encoding='utf-8') as chan:
            with contextlib.redirect_stdout(chan):
                {indented}
        """)
    return wrapped, open(r, encoding='utf-8')
|
|
|
|
|
|
def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *request* is wrapped by _captured_script() so its stdout lands in a
    pipe; *shared* is passed straight through to run_string().
    """
    wrapped, reader = _captured_script(request)
    try:
        interpreters.run_string(interp, wrapped, shared)
        captured = reader.read()
    finally:
        # Always release the read end of the pipe.
        reader.close()
    return captured
|
|
|
|
|
|
@contextlib.contextmanager
def _running(interp):
    """Keep interpreter *interp* busy for the duration of the ``with`` block.

    A helper thread runs a script in *interp* that blocks reading from a
    pipe.  On leaving the ``with`` block the write end is written to and
    closed, which lets the script finish, and the thread is joined.

    The release happens in a ``finally`` clause: if the body raises (for
    example a failed assertion), the helper thread is still unblocked and
    joined instead of being left waiting on the pipe forever.
    """
    r, w = os.pipe()

    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Closing the write end delivers EOF to the blocked read().
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()
|
|
|
|
|
|
class IsShareableTests(unittest.TestCase):
    """Checks of interpreters.is_shareable() against a range of objects."""

    def test_default_shareables(self):
        # Objects for which is_shareable() must report True.
        expected_shareable = [
            None,       # singletons
            b'spam',    # builtin objects
        ]
        for candidate in expected_shareable:
            with self.subTest(candidate):
                verdict = interpreters.is_shareable(candidate)
                self.assertTrue(verdict)

    def test_not_shareable(self):
        # Objects for which is_shareable() must report False.

        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_types_and_objects = [
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
        ]
        user_defined = [
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        rejected = singletons + builtin_types_and_objects + user_defined
        for candidate in rejected:
            with self.subTest(candidate):
                verdict = interpreters.is_shareable(candidate)
                self.assertFalse(verdict)
|
|
|
|
|
|
class TestBase(unittest.TestCase):
    """Base class whose tearDown() destroys leftover interpreters/channels."""

    def tearDown(self):
        # Destroy every interpreter except ID 0 (main).
        leftovers = [i for i in interpreters.list_all() if i != 0]
        for interp_id in leftovers:
            # RuntimeError here means it was already destroyed.
            with contextlib.suppress(RuntimeError):
                interpreters.destroy(interp_id)

        # Then destroy every channel that is still listed.
        for channel_id in interpreters.channel_list_all():
            # ChannelNotFoundError here means it was already destroyed.
            with contextlib.suppress(interpreters.ChannelNotFoundError):
                interpreters.channel_destroy(channel_id)
|
|
|
|
|
|
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        # With nothing created, only the main interpreter is listed.
        main_id = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main_id])

    def test_after_creating(self):
        # New interpreters are listed after main, in creation order.
        main_id = interpreters.get_main()
        created = [interpreters.create(), interpreters.create()]
        self.assertEqual(interpreters.list_all(), [main_id] + created)

    def test_after_destroying(self):
        # A destroyed interpreter no longer shows up in the listing.
        main_id = interpreters.get_main()
        doomed = interpreters.create()
        survivor = interpreters.create()
        interpreters.destroy(doomed)
        self.assertEqual(interpreters.list_all(), [main_id, survivor])
|
|
|
|
|
|
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        # Called from the test itself, get_current() equals get_main().
        main_id = interpreters.get_main()
        current_id = interpreters.get_current()
        self.assertEqual(current_id, main_id)

    def test_subinterpreter(self):
        # Inside a subinterpreter, get_current() reports that interpreter.
        main_id = interpreters.get_main()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_current())
            """)
        printed = _run_output(interp, script)
        reported = int(printed.strip())
        # Exactly two interpreters must exist: main and the one just created.
        _, expected = interpreters.list_all()
        self.assertEqual(reported, expected)
        self.assertNotEqual(reported, main_id)
|
|
|
|
|
|
class GetMainTests(TestBase):
    """Tests for interpreters.get_main()."""

    def test_from_main(self):
        # The only listed interpreter must be the one get_main() returns.
        (only_id,) = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), only_id)

    def test_from_subinterpreter(self):
        # get_main() called inside a subinterpreter still reports main.
        (only_id,) = interpreters.list_all()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_main())
            """)
        printed = _run_output(interp, script)
        reported = int(printed.strip())
        self.assertEqual(reported, only_id)
|
|
|
|
|
|
class IsRunningTests(TestBase):
    """Tests for interpreters.is_running()."""

    def test_main(self):
        # is_running() must report True for the main interpreter.
        main_id = interpreters.get_main()
        running = interpreters.is_running(main_id)
        self.assertTrue(running)

    def test_subinterpreter(self):
        # Not running before and after the _running() block, running inside.
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        # A subinterpreter asking about itself must see itself as running.
        interp = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """)
        printed = _run_output(interp, script)
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(RuntimeError, interpreters.is_running, -1)
|
|
|
|
|
|
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    # Script run inside a subinterpreter: create another one, print its ID.
    _CREATE_AND_PRINT = dedent("""
        import _xxsubinterpreters as _interpreters
        id = _interpreters.create()
        print(id)
        """)

    def test_in_main(self):
        new_id = interpreters.create()

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        # IDs must not be reused, even after the interpreter is destroyed.
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        gate = threading.Lock()
        created = None

        def create_then_wait():
            nonlocal created
            created = interpreters.create()
            # Block until the starting thread has released the lock.
            with gate:
                pass

        worker = threading.Thread(target=create_then_wait)
        with gate:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        (main_id,) = interpreters.list_all()
        parent = interpreters.create()
        printed = _run_output(parent, self._CREATE_AND_PRINT)
        child = int(printed.strip())

        self.assertEqual(set(interpreters.list_all()),
                         {main_id, parent, child})

    def test_in_threaded_subinterpreter(self):
        (main_id,) = interpreters.list_all()
        parent = interpreters.create()
        child = None

        def create_from_thread():
            nonlocal child
            printed = _run_output(parent, self._CREATE_AND_PRINT)
            child = int(printed.strip())

        worker = threading.Thread(target=create_from_thread)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()),
                         {main_id, parent, child})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        batch = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp_id in batch:
            interpreters.destroy(interp_id)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        first = interpreters.create()
        kept = interpreters.create()
        third = interpreters.create()
        # Now destroy 2 of them.
        interpreters.destroy(first)
        interpreters.destroy(third)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {latest, kept})
|
|
|
|
|
|
class DestroyTests(TestBase):
    """Tests for interpreters.destroy()."""

    def _run_in_thread(self, func):
        """Call func() in a new thread and re-raise its exception here.

        An exception (including a failed assertion) raised inside a plain
        threading.Thread target never reaches the test runner, so a check
        performed there could not fail the test.  This helper collects the
        exception and raises it in the calling thread after join().
        """
        caught = []

        def target():
            try:
                func()
            except BaseException as exc:
                caught.append(exc)

        t = threading.Thread(target=target)
        t.start()
        t.join()
        if caught:
            raise caught[0]

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        # Only the destroyed interpreter disappears from the listing.
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            id = interpreters.create()
            ids.add(id)
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for id in ids:
            interpreters.destroy(id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        # The main interpreter cannot be destroyed, neither from the test
        # thread nor from another thread.
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        def f():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        # Propagate a failed assertion from the worker thread.
        self._run_in_thread(f)

    def test_already_destroyed(self):
        id = interpreters.create()
        interpreters.destroy(id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(id)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(-1)

    def test_from_current(self):
        # An interpreter asking to destroy itself must still exist afterwards.
        main, = interpreters.list_all()
        id = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({id})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(id, script)
        self.assertEqual(set(interpreters.list_all()), {main, id})

    def test_from_sibling(self):
        # One subinterpreter destroys another one.
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({})
            """).format(id2)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        id = interpreters.create()

        def f():
            interpreters.destroy(id)

        # An exception raised by destroy() in the worker fails the test
        # instead of being swallowed by the thread.
        self._run_in_thread(f)

    def test_still_running(self):
        # Unpacking also checks that only the main interpreter exists.
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
|
|
|
|
|
|
class RunStringTests(TestBase):
    """Tests for interpreters.run_string().

    setUp() creates one subinterpreter, stored as ``self.id``, that the
    tests run their scripts in.
    """

    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.id = interpreters.create()
        self._fs = None

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        # NOTE(review): FSFixture is not defined or imported in the visible
        # part of this file, and no test here uses `fs`, SCRIPT or FILENAME;
        # accessing this property would presumably raise NameError -- confirm.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        # run_string() called from a thread other than the test thread.
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_create_thread(self):
        # The script itself starts (and joins) a thread.
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')

            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The forked child writes `expected` to the temp file and exits;
            # the parent polls waitpid() up to 10 times for it.
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0: # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    assert(spid == pid)
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        # Find the smallest ID that is not currently in use.
        id = 0
        while id in interpreters.list_all():
            id += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(id, 'print("spam")')

    def test_error_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        """Assert that the ``with`` body raises RunFailedError for *exctype*.

        Without *msg*, only the part of the error text before the first
        ':' is compared against ``str(exctype)``.  With *msg*, the whole
        text must equal ``"<exctype>: <msg>"``.
        """
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        if msg is None:
            self.assertEqual(str(caught.exception).split(':')[0],
                             str(exctype))
        else:
            self.assertEqual(str(caught.exception),
                             "{}: {}".format(exctype, msg))

    def test_invalid_syntax(self):
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        # The script pickles its namespace back to us through this pipe.
        r, w = os.pipe()

        shared = {
            'spam': b'ham',
            'eggs': b'-1',
            'cheddar': None,
        }
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        # First run: the script binds `spam` itself and snapshots it as ns1.
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        # Second run: `spam` is supplied via `shared`; snapshot as ns2.
        shared = {'spam': b'ham'}
        script = dedent("""
            ns2 = dict(vars())
            del ns2['__builtins__']
            """)
        interpreters.run_string(self.id, script, shared)

        # Third run: ship the accumulated namespace back through a pipe.
        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        # Names bound by one run_string() call must be visible to the next.
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        r, w = os.pipe()

        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        # These two entries are not compared.
        ns.pop('__builtins__')
        ns.pop('__loader__')
        self.assertEqual(ns, {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
        })

    def test_still_running_at_exit(self):
        # A child Python process starts a thread that runs a long sleep in
        # a subinterpreter, then reaches the end of its main script.  The
        # child must still exit with status 0.
        #
        # The child script has to create the interpreter it runs in:
        # without the `id = _interpreters.create()` line, `id` resolves to
        # the builtin function, run_string() is handed a non-ID and the
        # scenario named by this test is never exercised.
        script = dedent("""
            from textwrap import dedent
            import threading
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            def f():
                _interpreters.run_string(id, dedent('''
                    import time
                    # Give plenty of time for the main interpreter to finish.
                    time.sleep(1_000_000)
                    '''))

            t = threading.Thread(target=f)
            t.start()
            """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)
|
|
|
|
|
|
class ChannelIDTests(TestBase):
    """Tests for the object returned by interpreters._channel_id()."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        # (send/recv keyword arguments, expected value of `.end`)
        cases = [
            (dict(send=True), 'send'),
            (dict(send=True, recv=False), 'send'),
            (dict(recv=True), 'recv'),
            (dict(recv=True, send=False), 'recv'),
            (dict(send=True, recv=True), 'both'),
        ]
        for ends, expected in cases:
            cid = interpreters._channel_id(10, force=True, **ends)
            self.assertEqual(cid.end, expected)

    def test_coerce_id(self):
        class Int(str):
            def __init__(self, value):
                self._value = value

            def __int__(self):
                return self._value

        # A numeric string, a float and an object defining __int__().
        for raw in ('10', 10.0):
            cid = interpreters._channel_id(raw, force=True)
            self.assertEqual(int(cid), 10)

        cid = interpreters._channel_id(Int(10), force=True)
        self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for bad in [-1, 'spam']:
            with self.subTest(bad):
                self.assertRaises(ValueError, interpreters._channel_id, bad)
        self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
        self.assertRaises(TypeError, interpreters._channel_id, object())

    def test_bad_kwargs(self):
        with self.assertRaises(ValueError):
            interpreters._channel_id(10, send=False, recv=False)

    def test_does_not_exist(self):
        existing = interpreters.channel_create()
        missing = int(existing) + 1
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(missing)  # unforced

    def test_repr(self):
        # (send/recv keyword arguments, expected repr)
        cases = [
            (dict(), 'ChannelID(10)'),
            (dict(send=True), 'ChannelID(10, send=True)'),
            (dict(recv=True), 'ChannelID(10, recv=True)'),
            (dict(send=True, recv=True), 'ChannelID(10)'),
        ]
        for ends, expected in cases:
            cid = interpreters._channel_id(10, force=True, **ends)
            self.assertEqual(repr(cid), expected)

    def test_equality(self):
        original = interpreters.channel_create()
        same_channel = interpreters._channel_id(int(original))
        other_channel = interpreters.channel_create()

        # The operators are spelled out because they are what is under test.
        self.assertTrue(original == original)
        self.assertTrue(original == same_channel)
        self.assertTrue(original == int(original))
        self.assertFalse(original == other_channel)

        self.assertFalse(original != original)
        self.assertFalse(original != same_channel)
        self.assertTrue(original != other_channel)
|
|
|
|
|
|
class ChannelTests(TestBase):
    """Tests for the channel_* functions of the interpreters module."""

    def _assert_send_recv_closed(self, cid, payload):
        """Check that both sending *payload* on and receiving from *cid*
        raise ChannelClosedError (send is attempted first)."""
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, payload)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_sequential_ids(self):
        before = interpreters.channel_list_all()
        first = interpreters.channel_create()
        second = interpreters.channel_create()
        third = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(second, int(first) + 1)
        self.assertEqual(third, int(second) + 1)
        self.assertEqual(set(after) - set(before), {first, second, third})

    def test_ids_global(self):
        # Channels created in two different interpreters get consecutive IDs.
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """)

        interp1 = interpreters.create()
        cid1 = int(_run_output(interp1, script).strip())

        interp2 = interpreters.create()
        cid2 = int(_run_output(interp2, script).strip())

        self.assertEqual(cid2, int(cid1) + 1)

    ####################

    def test_drop_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        self._assert_send_recv_closed(cid, b'eggs')

    def test_drop_multiple_users(self):
        cid = interpreters.channel_create()
        num = int(cid)
        sender = interpreters.create()
        receiver = interpreters.create()
        interpreters.run_string(sender, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({num}, b'spam')
            """))
        printed = _run_output(receiver, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({num})
            _interpreters.channel_drop_interpreter({num})
            print(repr(obj))
            """))
        # The sender still has `_interpreters` bound from its first script.
        interpreters.run_string(sender, dedent(f"""
            _interpreters.channel_drop_interpreter({num})
            """))

        self.assertEqual(printed.strip(), "b'spam'")

    def test_drop_no_kwargs(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        self._assert_send_recv_closed(cid, b'eggs')

    def test_drop_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        # Dropping a second time must fail.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_drop_interpreter(cid, send=True, recv=True)

    def test_drop_with_unused_items(self):
        cid = interpreters.channel_create()
        for item in (b'spam', b'ham'):
            interpreters.channel_send(cid, item)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_drop_interpreter(cid)

        self._assert_send_recv_closed(cid, b'spam')

    def test_drop_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        outsider = interpreters.create()
        interpreters.run_string(outsider, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_drop_interpreter({int(cid)})
            """))
        # The item sent earlier must still be receivable here.
        received = interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(received, b'spam')

    def test_drop_close_if_unassociated(self):
        cid = interpreters.channel_create()
        user = interpreters.create()
        interpreters.run_string(user, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({int(cid)}, b'spam')
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_partially(self):
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        # Only the send end is dropped; receiving must still work.
        interpreters.channel_drop_interpreter(cid, send=True)
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_drop_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        self._assert_send_recv_closed(cid, b'eggs')

    ####################

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        self._assert_send_recv_closed(cid, b'eggs')

    def test_close_multiple_users(self):
        cid = interpreters.channel_create()
        num = int(cid)
        sender = interpreters.create()
        receiver = interpreters.create()
        interpreters.run_string(sender, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({num}, b'spam')
            """))
        interpreters.run_string(receiver, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({num})
            """))
        interpreters.channel_close(cid)
        # After the close, sending from either interpreter must fail.
        for interp in (sender, receiver):
            with self.assertRaises(interpreters.RunFailedError) as cm:
                interpreters.run_string(interp, dedent(f"""
                    _interpreters.channel_send({num}, b'spam')
                    """))
            self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        # Closing a second time must fail.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_with_unused_items(self):
        cid = interpreters.channel_create()
        for item in (b'spam', b'ham'):
            interpreters.channel_send(cid, item)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        self._assert_send_recv_closed(cid, b'spam')

    def test_close_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        outsider = interpreters.create()
        interpreters.run_string(outsider, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({int(cid)})
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        self._assert_send_recv_closed(cid, b'eggs')

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        sent = b'spam'
        interpreters.channel_send(cid, sent)
        received = interpreters.channel_recv(cid)

        # Equal content, but not the very same object.
        self.assertEqual(received, sent)
        self.assertIsNot(received, sent)

    def test_send_recv_same_interpreter(self):
        # The checks are the asserts inside the script; nothing is printed.
        interp = interpreters.create()
        _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()
        _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg(self):
        # The send end of the channel is handed to the script via `shared`.
        cid = interpreters.channel_create()
        interp = interpreters.create()

        script = dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """)
        printed = _run_output(interp, script, dict(cid=cid.send))
        received = interpreters.channel_recv(cid)

        self.assertEqual(received, b'spam')
        self.assertEqual(printed.strip(), 'send')
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
|