issue29167: fix race condition in (Int)Flag

This commit is contained in:
Ethan Furman 2017-01-24 12:13:34 -08:00
commit 0105606f55
2 changed files with 99 additions and 6 deletions

View File

@ -690,7 +690,9 @@ class Flag(Enum):
pseudo_member = object.__new__(cls) pseudo_member = object.__new__(cls)
pseudo_member._name_ = None pseudo_member._name_ = None
pseudo_member._value_ = value pseudo_member._value_ = value
cls._value2member_map_[value] = pseudo_member # use setdefault in case another thread already created a composite
# with this value
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
return pseudo_member return pseudo_member
def __contains__(self, other): def __contains__(self, other):
@ -785,7 +787,9 @@ class IntFlag(int, Flag):
pseudo_member = int.__new__(cls, value) pseudo_member = int.__new__(cls, value)
pseudo_member._name_ = None pseudo_member._name_ = None
pseudo_member._value_ = value pseudo_member._value_ = value
cls._value2member_map_[value] = pseudo_member # use setdefault in case another thread already created a composite
# with this value
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
return pseudo_member return pseudo_member
def __or__(self, other): def __or__(self, other):
@ -835,18 +839,21 @@ def _decompose(flag, value):
# _decompose is only called if the value is not named # _decompose is only called if the value is not named
not_covered = value not_covered = value
negative = value < 0 negative = value < 0
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race
# conditions between iterating over it and having more psuedo-
# members added to it
if negative: if negative:
# only check for named flags # only check for named flags
flags_to_check = [ flags_to_check = [
(m, v) (m, v)
for v, m in flag._value2member_map_.items() for v, m in list(flag._value2member_map_.items())
if m.name is not None if m.name is not None
] ]
else: else:
# check for named flags and powers-of-two flags # check for named flags and powers-of-two flags
flags_to_check = [ flags_to_check = [
(m, v) (m, v)
for v, m in flag._value2member_map_.items() for v, m in list(flag._value2member_map_.items())
if m.name is not None or _power_of_two(v) if m.name is not None or _power_of_two(v)
] ]
members = [] members = []

View File

@ -7,6 +7,11 @@ from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto
from io import StringIO from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support from test import support
try:
import threading
except ImportError:
threading = None
# for pickle tests # for pickle tests
try: try:
@ -1983,6 +1988,45 @@ class TestFlag(unittest.TestCase):
d = 6 d = 6
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>') self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
@unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_unique_composite(self):
# override __eq__ to be identity only
class TestFlag(Flag):
one = auto()
two = auto()
three = auto()
four = auto()
five = auto()
six = auto()
seven = auto()
eight = auto()
def __eq__(self, other):
return self is other
def __hash__(self):
return hash(self._value_)
# have multiple threads competing to complete the composite members
seen = set()
failed = False
def cycle_enum():
nonlocal failed
try:
for i in range(256):
seen.add(TestFlag(i))
except Exception:
failed = True
threads = [
threading.Thread(target=cycle_enum)
for _ in range(8)
]
with support.start_threads(threads):
pass
# check that only 248 members were created
self.assertFalse(
failed,
'at least one thread failed while creating composite members')
self.assertEqual(256, len(seen), 'too many composite members created')
class TestIntFlag(unittest.TestCase): class TestIntFlag(unittest.TestCase):
"""Tests of the IntFlags.""" """Tests of the IntFlags."""
@ -2275,6 +2319,46 @@ class TestIntFlag(unittest.TestCase):
for f in Open: for f in Open:
self.assertEqual(bool(f.value), bool(f)) self.assertEqual(bool(f.value), bool(f))
@unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_unique_composite(self):
# override __eq__ to be identity only
class TestFlag(IntFlag):
one = auto()
two = auto()
three = auto()
four = auto()
five = auto()
six = auto()
seven = auto()
eight = auto()
def __eq__(self, other):
return self is other
def __hash__(self):
return hash(self._value_)
# have multiple threads competing to complete the composite members
seen = set()
failed = False
def cycle_enum():
nonlocal failed
try:
for i in range(256):
seen.add(TestFlag(i))
except Exception:
failed = True
threads = [
threading.Thread(target=cycle_enum)
for _ in range(8)
]
with support.start_threads(threads):
pass
# check that only 248 members were created
self.assertFalse(
failed,
'at least one thread failed while creating composite members')
self.assertEqual(256, len(seen), 'too many composite members created')
class TestUnique(unittest.TestCase): class TestUnique(unittest.TestCase):
def test_unique_clean(self): def test_unique_clean(self):
@ -2488,7 +2572,8 @@ CONVERT_TEST_NAME_F = 5
class TestIntEnumConvert(unittest.TestCase): class TestIntEnumConvert(unittest.TestCase):
def test_convert_value_lookup_priority(self): def test_convert_value_lookup_priority(self):
test_type = enum.IntEnum._convert( test_type = enum.IntEnum._convert(
'UnittestConvert', 'test.test_enum', 'UnittestConvert',
('test.test_enum', '__main__')[__name__=='__main__'],
filter=lambda x: x.startswith('CONVERT_TEST_')) filter=lambda x: x.startswith('CONVERT_TEST_'))
# We don't want the reverse lookup value to vary when there are # We don't want the reverse lookup value to vary when there are
# multiple possible names for a given value. It should always # multiple possible names for a given value. It should always
@ -2497,7 +2582,8 @@ class TestIntEnumConvert(unittest.TestCase):
def test_convert(self): def test_convert(self):
test_type = enum.IntEnum._convert( test_type = enum.IntEnum._convert(
'UnittestConvert', 'test.test_enum', 'UnittestConvert',
('test.test_enum', '__main__')[__name__=='__main__'],
filter=lambda x: x.startswith('CONVERT_TEST_')) filter=lambda x: x.startswith('CONVERT_TEST_'))
# Ensure that test_type has all of the desired names and values. # Ensure that test_type has all of the desired names and values.
self.assertEqual(test_type.CONVERT_TEST_NAME_F, self.assertEqual(test_type.CONVERT_TEST_NAME_F,