import gc import time import unittest import weakref from ast import Or from functools import partial from threading import Thread from unittest import TestCase try: import _testcapi except ImportError: _testcapi = None from test.support import threading_helper @threading_helper.requires_working_threading() class TestDict(TestCase): def test_racing_creation_shared_keys(self): """Verify that creating dictionaries is thread safe when we have a type with shared keys""" class C(int): pass self.racing_creation(C) def test_racing_creation_no_shared_keys(self): """Verify that creating dictionaries is thread safe when we have a type with an ordinary dict""" self.racing_creation(Or) def test_racing_creation_inline_values_invalid(self): """Verify that re-creating a dict after we have invalid inline values is thread safe""" class C: pass def make_obj(): a = C() # Make object, make inline values invalid, and then delete dict a.__dict__ = {} del a.__dict__ return a self.racing_creation(make_obj) def test_racing_creation_nonmanaged_dict(self): """Verify that explicit creation of an unmanaged dict is thread safe outside of the normal attribute setting code path""" def make_obj(): def f(): pass return f def set(func, name, val): # Force creation of the dict via PyObject_GenericGetDict func.__dict__[name] = val self.racing_creation(make_obj, set) def racing_creation(self, cls, set=setattr): objects = [] processed = [] OBJECT_COUNT = 100 THREAD_COUNT = 10 CUR = 0 for i in range(OBJECT_COUNT): objects.append(cls()) def writer_func(name): last = -1 while True: if CUR == last: continue elif CUR == OBJECT_COUNT: break obj = objects[CUR] set(obj, name, name) last = CUR processed.append(name) writers = [] for x in range(THREAD_COUNT): writer = Thread(target=partial(writer_func, f"a{x:02}")) writers.append(writer) writer.start() for i in range(OBJECT_COUNT): CUR = i while len(processed) != THREAD_COUNT: time.sleep(0.001) processed.clear() CUR = OBJECT_COUNT for writer in writers: writer.join() for obj_idx, obj in enumerate(objects): assert ( len(obj.__dict__) == THREAD_COUNT ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}" for i in range(THREAD_COUNT): assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}" def test_racing_set_dict(self): """Races assigning to __dict__ should be thread safe""" def f(): pass l = [] THREAD_COUNT = 10 class MyDict(dict): pass def writer_func(l): for i in range(1000): d = MyDict() l.append(weakref.ref(d)) f.__dict__ = d lists = [] writers = [] for x in range(THREAD_COUNT): thread_list = [] lists.append(thread_list) writer = Thread(target=partial(writer_func, thread_list)) writers.append(writer) for writer in writers: writer.start() for writer in writers: writer.join() f.__dict__ = {} gc.collect() for thread_list in lists: for ref in thread_list: self.assertIsNone(ref()) def test_racing_set_object_dict(self): """Races assigning to __dict__ should be thread safe""" class C: pass class MyDict(dict): pass for cyclic in (False, True): f = C() f.__dict__ = {"foo": 42} THREAD_COUNT = 10 def writer_func(l): for i in range(1000): if cyclic: other_d = {} d = MyDict({"foo": 100}) if cyclic: d["x"] = other_d other_d["bar"] = d l.append(weakref.ref(d)) f.__dict__ = d def reader_func(): for i in range(1000): f.foo lists = [] readers = [] writers = [] for x in range(THREAD_COUNT): thread_list = [] lists.append(thread_list) writer = Thread(target=partial(writer_func, thread_list)) writers.append(writer) for x in range(THREAD_COUNT): reader = Thread(target=partial(reader_func)) readers.append(reader) for writer in writers: writer.start() for reader in readers: reader.start() for writer in writers: writer.join() for reader in readers: reader.join() f.__dict__ = {} gc.collect() gc.collect() count = 0 ids = set() for thread_list in lists: for i, ref in enumerate(thread_list): if ref() is None: continue count += 1 ids.add(id(ref())) count += 1 self.assertEqual(count, 0) if __name__ == "__main__": unittest.main()