bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)

This commit is contained in:
Ethan Furman 2022-01-22 18:27:52 -08:00 committed by GitHub
parent 976dec9b3b
commit 353e3b2820
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 42 deletions

View File

@ -618,6 +618,18 @@ class EnumType(type):
if name not in classdict:
setattr(enum_class, name, getattr(first_enum, name))
#
# for Flag, add __or__, __and__, __xor__, and __invert__
if Flag is not None and issubclass(enum_class, Flag):
for name in (
'__or__', '__and__', '__xor__',
'__ror__', '__rand__', '__rxor__',
'__invert__'
):
if name not in classdict:
enum_method = getattr(Flag, name)
setattr(enum_class, name, enum_method)
classdict[name] = enum_method
#
# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
@ -1467,19 +1479,34 @@ class Flag(Enum, boundary=STRICT):
return bool(self._value_)
def __or__(self, other):
if not isinstance(other, self.__class__):
if isinstance(other, self.__class__):
other = other._value_
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
return self.__class__(self._value_ | other._value_)
value = self._value_
return self.__class__(value | other)
def __and__(self, other):
if not isinstance(other, self.__class__):
if isinstance(other, self.__class__):
other = other._value_
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
return self.__class__(self._value_ & other._value_)
value = self._value_
return self.__class__(value & other)
def __xor__(self, other):
if not isinstance(other, self.__class__):
if isinstance(other, self.__class__):
other = other._value_
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
return self.__class__(self._value_ ^ other._value_)
value = self._value_
return self.__class__(value ^ other)
def __invert__(self):
if self._inverted_ is None:
@ -1493,6 +1520,10 @@ class Flag(Enum, boundary=STRICT):
self._inverted_._inverted_ = self
return self._inverted_
__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
"""
@ -1500,42 +1531,6 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
"""
def __or__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value | other)
def __and__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value & other)
def __xor__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value ^ other)
__ror__ = __or__
__rand__ = __and__
__rxor__ = __xor__
__invert__ = Flag.__invert__
def _high_bit(value):
"""
returns index of highest bit, or -1 if value is zero or negative
@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_inverted_'] = None
body['__or__'] = Flag.__or__
body['__xor__'] = Flag.__xor__
body['__and__'] = Flag.__and__
body['__ror__'] = Flag.__ror__
body['__rxor__'] = Flag.__rxor__
body['__rand__'] = Flag.__rand__
body['__invert__'] = Flag.__invert__
for name, obj in cls.__dict__.items():
if name in ('__dict__', '__weakref__'):
continue

View File

@ -2496,6 +2496,13 @@ class TestSpecial(unittest.TestCase):
self.assertEqual(Some.x.value, 1)
self.assertEqual(Some.y.value, 2)
def test_custom_flag_bitwise(self):
class MyIntFlag(int, Flag):
ONE = 1
TWO = 2
FOUR = 4
self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
class TestOrder(unittest.TestCase):
"test usage of the `_order_` attribute"