[3.9] bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714) (GH-23772)

When creating an Enum, `type.__new__` calls `__init_subclass__`, but at that point the members have not been added.

This patch suppresses the initial call, then manually calls the ancestor `__init_subclass__` before returning the new Enum class.
(cherry picked from commit 6bd94de168)
This commit is contained in:
Ethan Furman 2020-12-14 18:41:34 -08:00 committed by GitHub
parent aba12b67c1
commit 9d1fff1fcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 3 deletions

View File

@ -9,6 +9,14 @@ __all__ = [
] ]
class _NoInitSubclass:
"""
temporary base class to suppress calling __init_subclass__
"""
@classmethod
def __init_subclass__(cls, **kwds):
pass
def _is_descriptor(obj): def _is_descriptor(obj):
""" """
Returns True if obj is a descriptor, False otherwise. Returns True if obj is a descriptor, False otherwise.
@ -176,7 +184,7 @@ class EnumMeta(type):
) )
return enum_dict return enum_dict
def __new__(metacls, cls, bases, classdict): def __new__(metacls, cls, bases, classdict, **kwds):
# an Enum class is final once enumeration items have been defined; it # an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an # cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting # inherited __new__ unless a new __new__ is defined (or the resulting
@ -211,8 +219,22 @@ class EnumMeta(type):
if '__doc__' not in classdict: if '__doc__' not in classdict:
classdict['__doc__'] = 'An enumeration.' classdict['__doc__'] = 'An enumeration.'
# postpone calling __init_subclass__
if '__init_subclass__' in classdict and classdict['__init_subclass__'] is None:
raise TypeError('%s.__init_subclass__ cannot be None')
# remove current __init_subclass__ so previous one can be found with getattr
new_init_subclass = classdict.pop('__init_subclass__', None)
# create our new Enum type # create our new Enum type
enum_class = super().__new__(metacls, cls, bases, classdict) if bases:
bases = (_NoInitSubclass, ) + bases
enum_class = type.__new__(metacls, cls, bases, classdict)
enum_class.__bases__ = enum_class.__bases__[1:] #or (object, )
else:
enum_class = type.__new__(metacls, cls, bases, classdict)
old_init_subclass = getattr(enum_class, '__init_subclass__', None)
# and restore the new one (if there was one)
if new_init_subclass is not None:
enum_class.__init_subclass__ = classmethod(new_init_subclass)
enum_class._member_names_ = [] # names in definition order enum_class._member_names_ = [] # names in definition order
enum_class._member_map_ = {} # name->value map enum_class._member_map_ = {} # name->value map
enum_class._member_type_ = member_type enum_class._member_type_ = member_type
@ -324,6 +346,9 @@ class EnumMeta(type):
if _order_ != enum_class._member_names_: if _order_ != enum_class._member_names_:
raise TypeError('member order does not match _order_') raise TypeError('member order does not match _order_')
# finally, call parents' __init_subclass__
if Enum is not None and old_init_subclass is not None:
old_init_subclass(**kwds)
return enum_class return enum_class
def __bool__(self): def __bool__(self):
@ -701,6 +726,9 @@ class Enum(metaclass=EnumMeta):
else: else:
return start return start
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
@classmethod @classmethod
def _missing_(cls, value): def _missing_(cls, value):
return None return None

View File

@ -2049,7 +2049,6 @@ class TestEnum(unittest.TestCase):
local_ls = {} local_ls = {}
exec(code, global_ns, local_ls) exec(code, global_ns, local_ls)
@unittest.skipUnless( @unittest.skipUnless(
sys.version_info[:2] == (3, 9), sys.version_info[:2] == (3, 9),
'private variables are now normal attributes', 'private variables are now normal attributes',
@ -2066,6 +2065,42 @@ class TestEnum(unittest.TestCase):
except ValueError: except ValueError:
pass pass
def test_init_subclass(self):
class MyEnum(Enum):
def __init_subclass__(cls, **kwds):
super(MyEnum, cls).__init_subclass__(**kwds)
self.assertFalse(cls.__dict__.get('_test', False))
cls._test1 = 'MyEnum'
#
class TheirEnum(MyEnum):
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
cls._test2 = 'TheirEnum'
class WhoseEnum(TheirEnum):
def __init_subclass__(cls, **kwds):
pass
class NoEnum(WhoseEnum):
ONE = 1
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
self.assertFalse(NoEnum.__dict__.get('_test1', False))
self.assertFalse(NoEnum.__dict__.get('_test2', False))
#
class OurEnum(MyEnum):
def __init_subclass__(cls, **kwds):
cls._test2 = 'OurEnum'
class WhereEnum(OurEnum):
def __init_subclass__(cls, **kwds):
pass
class NeverEnum(WhereEnum):
ONE = 'one'
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
class TestOrder(unittest.TestCase): class TestOrder(unittest.TestCase):
@ -2516,6 +2551,42 @@ class TestFlag(unittest.TestCase):
'at least one thread failed while creating composite members') 'at least one thread failed while creating composite members')
self.assertEqual(256, len(seen), 'too many composite members created') self.assertEqual(256, len(seen), 'too many composite members created')
def test_init_subclass(self):
class MyEnum(Flag):
def __init_subclass__(cls, **kwds):
super().__init_subclass__(**kwds)
self.assertFalse(cls.__dict__.get('_test', False))
cls._test1 = 'MyEnum'
#
class TheirEnum(MyEnum):
def __init_subclass__(cls, **kwds):
super(TheirEnum, cls).__init_subclass__(**kwds)
cls._test2 = 'TheirEnum'
class WhoseEnum(TheirEnum):
def __init_subclass__(cls, **kwds):
pass
class NoEnum(WhoseEnum):
ONE = 1
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
self.assertFalse(NoEnum.__dict__.get('_test1', False))
self.assertFalse(NoEnum.__dict__.get('_test2', False))
#
class OurEnum(MyEnum):
def __init_subclass__(cls, **kwds):
cls._test2 = 'OurEnum'
class WhereEnum(OurEnum):
def __init_subclass__(cls, **kwds):
pass
class NeverEnum(WhereEnum):
ONE = 1
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
class TestIntFlag(unittest.TestCase): class TestIntFlag(unittest.TestCase):
"""Tests of the IntFlags.""" """Tests of the IntFlags."""

View File

@ -0,0 +1 @@
`Enum`: call `__init_subclass__` after members have been added