From 353e3b2820bed38da16140276786eef9ba33d3bd Mon Sep 17 00:00:00 2001 From: Ethan Furman Date: Sat, 22 Jan 2022 18:27:52 -0800 Subject: [PATCH] bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816) --- Lib/enum.py | 86 ++++++++++++++++++++++--------------------- Lib/test/test_enum.py | 7 ++++ 2 files changed, 51 insertions(+), 42 deletions(-) diff --git a/Lib/enum.py b/Lib/enum.py index b5104677312..85245c95f9a 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -618,6 +618,18 @@ class EnumType(type): if name not in classdict: setattr(enum_class, name, getattr(first_enum, name)) # + # for Flag, add __or__, __and__, __xor__, and __invert__ + if Flag is not None and issubclass(enum_class, Flag): + for name in ( + '__or__', '__and__', '__xor__', + '__ror__', '__rand__', '__rxor__', + '__invert__' + ): + if name not in classdict: + enum_method = getattr(Flag, name) + setattr(enum_class, name, enum_method) + classdict[name] = enum_method + # # replace any other __new__ with our own (as long as Enum is not None, # anyway) -- again, this is to support pickle if Enum is not None: @@ -1467,19 +1479,34 @@ class Flag(Enum, boundary=STRICT): return bool(self._value_) def __or__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ | other._value_) + value = self._value_ + return self.__class__(value | other) def __and__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ & other._value_) + value = self._value_ + return self.__class__(value & other) def __xor__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ ^ other._value_) + value = self._value_ + return self.__class__(value ^ other) def __invert__(self): if self._inverted_ is None: @@ -1493,6 +1520,10 @@ class Flag(Enum, boundary=STRICT): self._inverted_._inverted_ = self return self._inverted_ + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ + class IntFlag(int, ReprEnum, Flag, boundary=EJECT): """ @@ -1500,42 +1531,6 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT): """ - def __or__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif isinstance(other, int): - other = other - else: - return NotImplemented - value = self._value_ - return self.__class__(value | other) - - def __and__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif isinstance(other, int): - other = other - else: - return NotImplemented - value = self._value_ - return self.__class__(value & other) - - def __xor__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif isinstance(other, int): - other = other - else: - return NotImplemented - value = self._value_ - return self.__class__(value ^ other) - - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ - __invert__ = Flag.__invert__ - - def _high_bit(value): """ returns index of highest bit, or -1 if value is zero or negative @@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): body['_flag_mask_'] = None body['_all_bits_'] = None body['_inverted_'] = None + body['__or__'] = Flag.__or__ + body['__xor__'] = Flag.__xor__ + body['__and__'] = Flag.__and__ + body['__ror__'] = Flag.__ror__ + body['__rxor__'] = Flag.__rxor__ + body['__rand__'] = Flag.__rand__ + body['__invert__'] = Flag.__invert__ for name, obj in cls.__dict__.items(): if name in ('__dict__', '__weakref__'): continue diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index d7ce8add787..b8a7914355c 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2496,6 +2496,13 @@ class TestSpecial(unittest.TestCase): self.assertEqual(Some.x.value, 1) self.assertEqual(Some.y.value, 2) + def test_custom_flag_bitwise(self): + class MyIntFlag(int, Flag): + ONE = 1 + TWO = 2 + FOUR = 4 + self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO) + self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag)) class TestOrder(unittest.TestCase): "test usage of the `_order_` attribute"