bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714)

When creating an Enum, type.__new__ calls __init_subclass__, but at that point the members have not been added.

This patch suppresses the initial call, then manually calls the ancestor __init_subclass__ before returning the new Enum class.
This commit is contained in:
Ethan Furman 2020-12-09 16:41:22 -08:00 committed by GitHub
parent 2a35137328
commit 6bd94de168
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 2 deletions

View File

@ -9,6 +9,14 @@ __all__ = [
] ]
class _NoInitSubclass:
"""
temporary base class to suppress calling __init_subclass__
"""
@classmethod
def __init_subclass__(cls, **kwds):
pass
def _is_descriptor(obj): def _is_descriptor(obj):
""" """
Returns True if obj is a descriptor, False otherwise. Returns True if obj is a descriptor, False otherwise.
@ -157,7 +165,7 @@ class EnumMeta(type):
) )
return enum_dict return enum_dict
def __new__(metacls, cls, bases, classdict): def __new__(metacls, cls, bases, classdict, **kwds):
# an Enum class is final once enumeration items have been defined; it # an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an # cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting # inherited __new__ unless a new __new__ is defined (or the resulting
@ -192,8 +200,22 @@ class EnumMeta(type):
if '__doc__' not in classdict: if '__doc__' not in classdict:
classdict['__doc__'] = 'An enumeration.' classdict['__doc__'] = 'An enumeration.'
# postpone calling __init_subclass__
if '__init_subclass__' in classdict and classdict['__init_subclass__'] is None:
raise TypeError('%s.__init_subclass__ cannot be None')
# remove current __init_subclass__ so previous one can be found with getattr
new_init_subclass = classdict.pop('__init_subclass__', None)
# create our new Enum type # create our new Enum type
enum_class = super().__new__(metacls, cls, bases, classdict) if bases:
bases = (_NoInitSubclass, ) + bases
enum_class = type.__new__(metacls, cls, bases, classdict)
enum_class.__bases__ = enum_class.__bases__[1:] #or (object, )
else:
enum_class = type.__new__(metacls, cls, bases, classdict)
old_init_subclass = getattr(enum_class, '__init_subclass__', None)
# and restore the new one (if there was one)
if new_init_subclass is not None:
enum_class.__init_subclass__ = classmethod(new_init_subclass)
enum_class._member_names_ = [] # names in definition order enum_class._member_names_ = [] # names in definition order
enum_class._member_map_ = {} # name->value map enum_class._member_map_ = {} # name->value map
enum_class._member_type_ = member_type enum_class._member_type_ = member_type
@ -305,6 +327,9 @@ class EnumMeta(type):
if _order_ != enum_class._member_names_: if _order_ != enum_class._member_names_:
raise TypeError('member order does not match _order_') raise TypeError('member order does not match _order_')
# finally, call parents' __init_subclass__
if Enum is not None and old_init_subclass is not None:
old_init_subclass(**kwds)
return enum_class return enum_class
def __bool__(self): def __bool__(self):
@ -682,6 +707,9 @@ class Enum(metaclass=EnumMeta):
else: else:
return start return start
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
@classmethod @classmethod
def _missing_(cls, value): def _missing_(cls, value):
return None return None

View File

@ -2117,6 +2117,43 @@ class TestEnum(unittest.TestCase):
class ThirdFailedStrEnum(StrEnum): class ThirdFailedStrEnum(StrEnum):
one = '1' one = '1'
two = b'2', 'ascii', 9 two = b'2', 'ascii', 9
def test_init_subclass(self):
class MyEnum(Enum):
def __init_subclass__(cls, **kwds):
super(MyEnum, cls).__init_subclass__(**kwds)
self.assertFalse(cls.__dict__.get('_test', False))
cls._test1 = 'MyEnum'
#
class TheirEnum(MyEnum):
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
cls._test2 = 'TheirEnum'
class WhoseEnum(TheirEnum):
def __init_subclass__(cls, **kwds):
pass
class NoEnum(WhoseEnum):
ONE = 1
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
self.assertFalse(NoEnum.__dict__.get('_test1', False))
self.assertFalse(NoEnum.__dict__.get('_test2', False))
#
class OurEnum(MyEnum):
def __init_subclass__(cls, **kwds):
cls._test2 = 'OurEnum'
class WhereEnum(OurEnum):
def __init_subclass__(cls, **kwds):
pass
class NeverEnum(WhereEnum):
ONE = 'one'
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
class TestOrder(unittest.TestCase): class TestOrder(unittest.TestCase):
@ -2573,6 +2610,42 @@ class TestFlag(unittest.TestCase):
'at least one thread failed while creating composite members') 'at least one thread failed while creating composite members')
self.assertEqual(256, len(seen), 'too many composite members created') self.assertEqual(256, len(seen), 'too many composite members created')
def test_init_subclass(self):
class MyEnum(Flag):
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
self.assertFalse(cls.__dict__.get('_test', False))
cls._test1 = 'MyEnum'
#
class TheirEnum(MyEnum):
def __init_subclass__(cls, **kwds):
super(TheirEnum, cls).__init_subclass__(**kwds)
cls._test2 = 'TheirEnum'
class WhoseEnum(TheirEnum):
def __init_subclass__(cls, **kwds):
pass
class NoEnum(WhoseEnum):
ONE = 1
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
self.assertFalse(NoEnum.__dict__.get('_test1', False))
self.assertFalse(NoEnum.__dict__.get('_test2', False))
#
class OurEnum(MyEnum):
def __init_subclass__(cls, **kwds):
cls._test2 = 'OurEnum'
class WhereEnum(OurEnum):
def __init_subclass__(cls, **kwds):
pass
class NeverEnum(WhereEnum):
ONE = 1
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
class TestIntFlag(unittest.TestCase): class TestIntFlag(unittest.TestCase):
"""Tests of the IntFlags.""" """Tests of the IntFlags."""

View File

@ -0,0 +1 @@
`Enum`: call `__init_subclass__` after members have been added