# cpython/Lib/test/test_interpreters.py
#
# 536 lines
# 15 KiB
# Python

import contextlib
import os
import threading
from textwrap import dedent
import unittest
import time
import _xxsubinterpreters as _interpreters
from test.support import interpreters
def _captured_script(script):
    """Wrap *script* so that what it prints is written into a new pipe.

    Returns a ``(wrapped, rpipe)`` pair.  *wrapped* is source text that
    opens the write end of the pipe, redirects stdout to it and then runs
    *script*.  *rpipe* is an open text file on the read end of the pipe;
    closing it is left to the caller.
    """
    r, w = os.pipe()
    # The template is dedented only after interpolation, so every
    # continuation line of *script* has to be pushed to the column of the
    # placeholder below (16 spaces) to stay inside the inner "with" body.
    indented = script.replace('\n', '\n                ')
    wrapped = dedent(f"""
        import contextlib
        with open({w}, 'w') as spipe:
            with contextlib.redirect_stdout(spipe):
                {indented}
        """)
    return wrapped, open(r)
def clean_up_interpreters():
    """Close every listed interpreter except the one whose id is 0 (main).

    A RuntimeError raised by close() is ignored; it is taken to mean the
    interpreter was already destroyed.
    """
    for interp in interpreters.list_all():
        if interp.id == 0:  # main
            continue
        try:
            interp.close()
        except RuntimeError:
            pass  # already destroyed
def _run_output(interp, request, shared=None):
    """Run *request* in *interp* and return the text it wrote to stdout.

    *shared* is accepted but not used by this helper.
    """
    script, rpipe = _captured_script(request)
    with rpipe:
        interp.run(script)
        return rpipe.read()
@contextlib.contextmanager
def _running(interp):
    """Context manager that keeps *interp* busy while the body executes.

    A worker thread runs a script in *interp* that blocks reading from a
    pipe.  On exit the write end of the pipe is written to and closed,
    which lets the script finish, and the worker thread is joined.
    """
    r, w = os.pipe()

    def run():
        interp.run(dedent(f"""
            # wait for "signal"
            with open({r}) as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Release the worker even when the with-body raised (for example
        # a failing assertion); otherwise the thread would stay blocked
        # on the pipe and never be joined.
        with open(w, 'w') as spipe:
            spipe.write('done')
        t.join()
class TestBase(unittest.TestCase):
    """Common base class: tearDown calls clean_up_interpreters()."""

    def tearDown(self):
        clean_up_interpreters()
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    def test_in_main(self):
        interp = interpreters.create()
        lst = interpreters.list_all()
        self.assertEqual(interp.id, lst[1].id)

    def test_in_thread(self):
        lock = threading.Lock()
        seen_id = None
        interp = interpreters.create()
        lst = interpreters.list_all()

        def f():
            nonlocal seen_id
            seen_id = interp.id
            # Block until the test thread has left its "with lock" block.
            lock.acquire()
            lock.release()

        t = threading.Thread(target=f)
        with lock:
            t.start()
        t.join()
        # The id read from the worker thread must match the one seen here.
        self.assertEqual(seen_id, interp.id)
        self.assertEqual(interp.id, lst[1].id)

    def test_in_subinterpreter(self):
        # The unpacking also checks that only one interpreter exists.
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(interp)
            """))
        # interp2 is the printed text of the nested interpreter, not an
        # Interpreter object; only the number of distinct items is compared.
        interp2 = out.strip()
        self.assertEqual(len(set(interpreters.list_all())),
                         len({main, interp, interp2}))

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp_lst = []
        for _ in range(3):
            created = interpreters.create()
            interp_lst.append(created)
        # Now destroy them.
        for interp in interp_lst:
            interp.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(len(set(interpreters.list_all())),
                         len(before | {interp}))

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        # Now destroy 2 of them.
        interp1.close()
        interp2.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(len(set(interpreters.list_all())),
                         len(before | {interp3, interp}))
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main_interp_id = _interpreters.get_main()
        cur_interp_id = interpreters.get_current().id
        self.assertEqual(cur_interp_id, main_interp_id)

    def test_subinterpreter(self):
        main = _interpreters.get_main()
        interp = interpreters.create()
        # Print the numeric id rather than the object: comparing printed
        # object text with an interpreter id could never be equal, which
        # would make the assertion below pass unconditionally.
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(cur.id)
            """))
        cur = int(out.strip())
        self.assertNotEqual(cur, int(main))
        self.assertEqual(cur, int(interp.id))
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        interps = interpreters.list_all()
        self.assertEqual(1, len(interps))

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, first.id, second.id])

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, second.id])
class TestInterpreterId(TestBase):
    """Tests for the Interpreter.id attribute."""

    def test_in_main(self):
        main = interpreters.get_current()
        self.assertEqual(0, main.id)

    def test_with_custom_num(self):
        interp = interpreters.Interpreter(1)
        self.assertEqual(1, interp.id)

    def test_for_readonly_property(self):
        interp = interpreters.Interpreter(1)
        # Assigning to .id is expected to be rejected.
        with self.assertRaises(AttributeError):
            interp.id = 2
class TestInterpreterIsRunning(TestBase):
    """Tests for Interpreter.is_running()."""

    def test_main(self):
        main = interpreters.get_current()
        self.assertTrue(main.is_running())

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interp.is_running())

        with _running(interp):
            self.assertTrue(interp.is_running())
        self.assertFalse(interp.is_running())

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        # The script asks, from inside the interpreter, whether that same
        # interpreter (identified by its id) is running.
        out = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp.id}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.is_running()
class TestInterpreterDestroy(TestBase):
    """Tests for Interpreter.close()."""

    def test_basic(self):
        # interp1 and interp3 are never closed here; they stay bound so
        # that three interpreters remain after interp2 is closed.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(4, len(interpreters.list_all()))
        interp2.close()
        self.assertEqual(3, len(interpreters.list_all()))

    def test_all(self):
        before = set(interpreters.list_all())
        interps = {interpreters.create() for _ in range(3)}
        self.assertEqual(len(set(interpreters.list_all())),
                         len(before | interps))
        for interp in interps:
            interp.close()
        self.assertEqual(len(set(interpreters.list_all())), len(before))

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        # An assertion that fails inside a worker thread does not fail the
        # test, so record what happened and check it on the test thread.
        outcome = []

        def f():
            try:
                main.close()
            except RuntimeError:
                outcome.append('RuntimeError')
            else:
                outcome.append('no error')

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(outcome, ['RuntimeError'])

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        # Inside the script, "main" is whatever get_current() returns
        # there; a RuntimeError from closing it is swallowed by the script.
        script = dedent("""
            from test.support import interpreters
            try:
                main = interpreters.get_current()
                main.close()
            except RuntimeError:
                pass
            """)

        interp.run(script)
        self.assertEqual(len(set(interpreters.list_all())),
                         len({main, interp}))

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            interp2 = interpreters.create()
            interp2.close()
            """)
        interp1.run(script)

        self.assertEqual(len(set(interpreters.list_all())),
                         len({main, interp1}))

    def test_from_other_thread(self):
        interp = interpreters.create()
        # Exceptions raised in a worker thread are not reported to the
        # test, so collect them and check on the test thread.
        errors = []

        def f():
            try:
                interp.close()
            except Exception as exc:
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(errors, [])

    def test_still_running(self):
        # The unpacking also checks that only one interpreter exists.
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())
class TestInterpreterRun(TestBase):
    """Tests for Interpreter.run()."""

    # Template with two "{}" placeholders: a file name and the text to
    # write into that file.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.interp = interpreters.create()
        self._fs = None

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        # Created lazily on first access.
        # NOTE(review): FSFixture is not defined or imported in the part
        # of the file visible here -- confirm it exists before using this.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            self.interp.run(script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                self.interp.run(script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The script writes *expected* to the temp file only when
            # os.fork() raises RuntimeError inside the interpreter.
            script = dedent(f"""
                import os
                try:
                    os.fork()
                except RuntimeError:
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                """)
            self.interp.run(script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.interp):
            with self.assertRaises(RuntimeError):
                self.interp.run('print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(b'print("spam")')
class TestIsShareable(TestBase):
    """Tests for interpreters.is_shareable()."""

    def test_default_shareables(self):
        shareables = [
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        ]
        for obj in shareables:
            # repr() keeps the labels unambiguous (b'spam' vs 'spam') and
            # matches the labelling used in test_not_shareable.
            with self.subTest(repr(obj)):
                self.assertTrue(
                    interpreters.is_shareable(obj))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        not_shareables = [
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            100.0,
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        for obj in not_shareables:
            with self.subTest(repr(obj)):
                self.assertFalse(
                    interpreters.is_shareable(obj))
class TestChannel(TestBase):
    """Tests for interpreters.create_channel() and list_all_channels()."""

    def test_create_cid(self):
        r, s = interpreters.create_channel()
        self.assertIsInstance(r, interpreters.RecvChannel)
        self.assertIsInstance(s, interpreters.SendChannel)

    def test_sequential_ids(self):
        before = interpreters.list_all_channels()
        channels1 = interpreters.create_channel()
        channels2 = interpreters.create_channel()
        channels3 = interpreters.create_channel()
        after = interpreters.list_all_channels()

        # Only the number of newly listed entries is compared with the
        # number of channels created; the ids themselves are not inspected.
        self.assertEqual(len(set(after) - set(before)),
                         len({channels1, channels2, channels3}))
class TestSendRecv(TestBase):
def test_send_recv_main(self):
    # Send and receive on one channel from the current interpreter.
    r, s = interpreters.create_channel()
    orig = b'spam'
    s.send(orig)
    obj = r.recv()

    # The received object is equal to the one sent but is a new object.
    self.assertEqual(obj, orig)
    self.assertIsNot(obj, orig)
def test_send_recv_same_interpreter(self):
    interp = interpreters.create()
    # The checks are the assert statements inside the script; the text
    # returned by _run_output() is not needed.
    _run_output(interp, dedent("""
        from test.support import interpreters
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send(orig)
        obj = r.recv()
        assert obj is not orig
        assert obj == orig
        """))
def test_send_recv_different_threads(self):
    r, s = interpreters.create_channel()

    def f():
        # Poll until an object is available, then send it straight back
        # on the same channel.
        while True:
            try:
                obj = r.recv()
                break
            except interpreters.ChannelEmptyError:
                time.sleep(0.1)
        s.send(obj)

    t = threading.Thread(target=f)
    t.start()

    s.send(b'spam')
    t.join()
    obj = r.recv()

    self.assertEqual(obj, b'spam')
def test_send_recv_nowait_main(self):
    # Same as the plain send/recv round trip, but using recv_nowait().
    r, s = interpreters.create_channel()
    orig = b'spam'
    s.send(orig)
    obj = r.recv_nowait()

    # The received object is equal to the one sent but is a new object.
    self.assertEqual(obj, orig)
    self.assertIsNot(obj, orig)
def test_send_recv_nowait_same_interpreter(self):
    interp = interpreters.create()
    # The checks are the assert statements inside the script; the text
    # returned by _run_output() is not needed.
    _run_output(interp, dedent("""
        from test.support import interpreters
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send(orig)
        obj = r.recv_nowait()
        assert obj is not orig
        assert obj == orig
        """))
r, s = interpreters.create_channel()
def f():
while True:
try:
obj = r.recv_nowait()
break
except _interpreters.ChannelEmptyError:
time.sleep(0.1)
s.send(obj)