Close #18508 -- fix _value2member_map to always have the member's value

This commit is contained in:
Ethan Furman 2013-07-19 19:35:56 -07:00
parent e410f267f1
commit 2aa2732eaf
2 changed files with 31 additions and 11 deletions

View File

@ -1,5 +1,3 @@
"""Python Enumerations"""
import sys
from collections import OrderedDict
from types import MappingProxyType
@ -154,11 +152,13 @@ class EnumMeta(type):
args = (args, ) # wrap it one more time
if not use_args:
enum_member = __new__(enum_class)
enum_member._value = value
original_value = value
else:
enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, '_value'):
enum_member._value = member_type(*args)
original_value = member_type(*args)
if not hasattr(enum_member, '_value'):
enum_member._value = original_value
value = enum_member._value
enum_member._member_type = member_type
enum_member._name = member_name
enum_member.__init__(*args)
@ -416,12 +416,14 @@ class Enum(metaclass=EnumMeta):
return value
# by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values)
if value in cls._value2member_map:
return cls._value2member_map[value]
# not there, now do long search -- O(n) behavior
for member in cls._member_map.values():
if member.value == value:
return member
try:
if value in cls._value2member_map:
return cls._value2member_map[value]
except TypeError:
# not there, now do long search -- O(n) behavior
for member in cls._member_map.values():
if member.value == value:
return member
raise ValueError("%s is not a valid %s" % (value, cls.__name__))
def __repr__(self):

View File

@ -694,6 +694,7 @@ class TestEnum(unittest.TestCase):
x = ('the-x', 1)
y = ('the-y', 2)
self.assertIs(NEI.__new__, Enum.__new__)
self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
globals()['NamedInt'] = NamedInt
@ -785,6 +786,7 @@ class TestEnum(unittest.TestCase):
[AutoNumber.first, AutoNumber.second, AutoNumber.third],
)
self.assertEqual(int(AutoNumber.second), 2)
self.assertEqual(AutoNumber.third.value, 3)
self.assertIs(AutoNumber(1), AutoNumber.first)
def test_inherited_new_from_enhanced_enum(self):
@ -916,6 +918,22 @@ class TestEnum(unittest.TestCase):
self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80)
self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6))
def test_nonhash_value(self):
class AutoNumberInAList(Enum):
def __new__(cls):
value = [len(cls.__members__) + 1]
obj = object.__new__(cls)
obj._value = value
return obj
class ColorInAList(AutoNumberInAList):
red = ()
green = ()
blue = ()
self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue])
self.assertEqual(ColorInAList.red.value, [1])
self.assertEqual(ColorInAList([1]), ColorInAList.red)
class TestUnique(unittest.TestCase):