mirror of https://github.com/python/cpython
issue29167: fix race condition in (Int)Flag
This commit is contained in:
commit
0105606f55
15
Lib/enum.py
15
Lib/enum.py
|
@ -690,7 +690,9 @@ class Flag(Enum):
|
||||||
pseudo_member = object.__new__(cls)
|
pseudo_member = object.__new__(cls)
|
||||||
pseudo_member._name_ = None
|
pseudo_member._name_ = None
|
||||||
pseudo_member._value_ = value
|
pseudo_member._value_ = value
|
||||||
cls._value2member_map_[value] = pseudo_member
|
# use setdefault in case another thread already created a composite
|
||||||
|
# with this value
|
||||||
|
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
|
||||||
return pseudo_member
|
return pseudo_member
|
||||||
|
|
||||||
def __contains__(self, other):
|
def __contains__(self, other):
|
||||||
|
@ -785,7 +787,9 @@ class IntFlag(int, Flag):
|
||||||
pseudo_member = int.__new__(cls, value)
|
pseudo_member = int.__new__(cls, value)
|
||||||
pseudo_member._name_ = None
|
pseudo_member._name_ = None
|
||||||
pseudo_member._value_ = value
|
pseudo_member._value_ = value
|
||||||
cls._value2member_map_[value] = pseudo_member
|
# use setdefault in case another thread already created a composite
|
||||||
|
# with this value
|
||||||
|
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
|
||||||
return pseudo_member
|
return pseudo_member
|
||||||
|
|
||||||
def __or__(self, other):
|
def __or__(self, other):
|
||||||
|
@ -835,18 +839,21 @@ def _decompose(flag, value):
|
||||||
# _decompose is only called if the value is not named
|
# _decompose is only called if the value is not named
|
||||||
not_covered = value
|
not_covered = value
|
||||||
negative = value < 0
|
negative = value < 0
|
||||||
|
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race
|
||||||
|
# conditions between iterating over it and having more psuedo-
|
||||||
|
# members added to it
|
||||||
if negative:
|
if negative:
|
||||||
# only check for named flags
|
# only check for named flags
|
||||||
flags_to_check = [
|
flags_to_check = [
|
||||||
(m, v)
|
(m, v)
|
||||||
for v, m in flag._value2member_map_.items()
|
for v, m in list(flag._value2member_map_.items())
|
||||||
if m.name is not None
|
if m.name is not None
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# check for named flags and powers-of-two flags
|
# check for named flags and powers-of-two flags
|
||||||
flags_to_check = [
|
flags_to_check = [
|
||||||
(m, v)
|
(m, v)
|
||||||
for v, m in flag._value2member_map_.items()
|
for v, m in list(flag._value2member_map_.items())
|
||||||
if m.name is not None or _power_of_two(v)
|
if m.name is not None or _power_of_two(v)
|
||||||
]
|
]
|
||||||
members = []
|
members = []
|
||||||
|
|
|
@ -7,6 +7,11 @@ from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
|
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
|
||||||
from test import support
|
from test import support
|
||||||
|
try:
|
||||||
|
import threading
|
||||||
|
except ImportError:
|
||||||
|
threading = None
|
||||||
|
|
||||||
|
|
||||||
# for pickle tests
|
# for pickle tests
|
||||||
try:
|
try:
|
||||||
|
@ -1983,6 +1988,45 @@ class TestFlag(unittest.TestCase):
|
||||||
d = 6
|
d = 6
|
||||||
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
|
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
|
||||||
|
|
||||||
|
@unittest.skipUnless(threading, 'Threading required for this test.')
|
||||||
|
@support.reap_threads
|
||||||
|
def test_unique_composite(self):
|
||||||
|
# override __eq__ to be identity only
|
||||||
|
class TestFlag(Flag):
|
||||||
|
one = auto()
|
||||||
|
two = auto()
|
||||||
|
three = auto()
|
||||||
|
four = auto()
|
||||||
|
five = auto()
|
||||||
|
six = auto()
|
||||||
|
seven = auto()
|
||||||
|
eight = auto()
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self is other
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self._value_)
|
||||||
|
# have multiple threads competing to complete the composite members
|
||||||
|
seen = set()
|
||||||
|
failed = False
|
||||||
|
def cycle_enum():
|
||||||
|
nonlocal failed
|
||||||
|
try:
|
||||||
|
for i in range(256):
|
||||||
|
seen.add(TestFlag(i))
|
||||||
|
except Exception:
|
||||||
|
failed = True
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=cycle_enum)
|
||||||
|
for _ in range(8)
|
||||||
|
]
|
||||||
|
with support.start_threads(threads):
|
||||||
|
pass
|
||||||
|
# check that only 248 members were created
|
||||||
|
self.assertFalse(
|
||||||
|
failed,
|
||||||
|
'at least one thread failed while creating composite members')
|
||||||
|
self.assertEqual(256, len(seen), 'too many composite members created')
|
||||||
|
|
||||||
|
|
||||||
class TestIntFlag(unittest.TestCase):
|
class TestIntFlag(unittest.TestCase):
|
||||||
"""Tests of the IntFlags."""
|
"""Tests of the IntFlags."""
|
||||||
|
@ -2275,6 +2319,46 @@ class TestIntFlag(unittest.TestCase):
|
||||||
for f in Open:
|
for f in Open:
|
||||||
self.assertEqual(bool(f.value), bool(f))
|
self.assertEqual(bool(f.value), bool(f))
|
||||||
|
|
||||||
|
@unittest.skipUnless(threading, 'Threading required for this test.')
|
||||||
|
@support.reap_threads
|
||||||
|
def test_unique_composite(self):
|
||||||
|
# override __eq__ to be identity only
|
||||||
|
class TestFlag(IntFlag):
|
||||||
|
one = auto()
|
||||||
|
two = auto()
|
||||||
|
three = auto()
|
||||||
|
four = auto()
|
||||||
|
five = auto()
|
||||||
|
six = auto()
|
||||||
|
seven = auto()
|
||||||
|
eight = auto()
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self is other
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self._value_)
|
||||||
|
# have multiple threads competing to complete the composite members
|
||||||
|
seen = set()
|
||||||
|
failed = False
|
||||||
|
def cycle_enum():
|
||||||
|
nonlocal failed
|
||||||
|
try:
|
||||||
|
for i in range(256):
|
||||||
|
seen.add(TestFlag(i))
|
||||||
|
except Exception:
|
||||||
|
failed = True
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=cycle_enum)
|
||||||
|
for _ in range(8)
|
||||||
|
]
|
||||||
|
with support.start_threads(threads):
|
||||||
|
pass
|
||||||
|
# check that only 248 members were created
|
||||||
|
self.assertFalse(
|
||||||
|
failed,
|
||||||
|
'at least one thread failed while creating composite members')
|
||||||
|
self.assertEqual(256, len(seen), 'too many composite members created')
|
||||||
|
|
||||||
|
|
||||||
class TestUnique(unittest.TestCase):
|
class TestUnique(unittest.TestCase):
|
||||||
|
|
||||||
def test_unique_clean(self):
|
def test_unique_clean(self):
|
||||||
|
@ -2488,7 +2572,8 @@ CONVERT_TEST_NAME_F = 5
|
||||||
class TestIntEnumConvert(unittest.TestCase):
|
class TestIntEnumConvert(unittest.TestCase):
|
||||||
def test_convert_value_lookup_priority(self):
|
def test_convert_value_lookup_priority(self):
|
||||||
test_type = enum.IntEnum._convert(
|
test_type = enum.IntEnum._convert(
|
||||||
'UnittestConvert', 'test.test_enum',
|
'UnittestConvert',
|
||||||
|
('test.test_enum', '__main__')[__name__=='__main__'],
|
||||||
filter=lambda x: x.startswith('CONVERT_TEST_'))
|
filter=lambda x: x.startswith('CONVERT_TEST_'))
|
||||||
# We don't want the reverse lookup value to vary when there are
|
# We don't want the reverse lookup value to vary when there are
|
||||||
# multiple possible names for a given value. It should always
|
# multiple possible names for a given value. It should always
|
||||||
|
@ -2497,7 +2582,8 @@ class TestIntEnumConvert(unittest.TestCase):
|
||||||
|
|
||||||
def test_convert(self):
|
def test_convert(self):
|
||||||
test_type = enum.IntEnum._convert(
|
test_type = enum.IntEnum._convert(
|
||||||
'UnittestConvert', 'test.test_enum',
|
'UnittestConvert',
|
||||||
|
('test.test_enum', '__main__')[__name__=='__main__'],
|
||||||
filter=lambda x: x.startswith('CONVERT_TEST_'))
|
filter=lambda x: x.startswith('CONVERT_TEST_'))
|
||||||
# Ensure that test_type has all of the desired names and values.
|
# Ensure that test_type has all of the desired names and values.
|
||||||
self.assertEqual(test_type.CONVERT_TEST_NAME_F,
|
self.assertEqual(test_type.CONVERT_TEST_NAME_F,
|
||||||
|
|
Loading…
Reference in New Issue