mirror of https://github.com/python/cpython
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
This commit is contained in:
parent
976dec9b3b
commit
353e3b2820
86
Lib/enum.py
86
Lib/enum.py
|
@ -618,6 +618,18 @@ class EnumType(type):
|
|||
if name not in classdict:
|
||||
setattr(enum_class, name, getattr(first_enum, name))
|
||||
#
|
||||
# for Flag, add __or__, __and__, __xor__, and __invert__
|
||||
if Flag is not None and issubclass(enum_class, Flag):
|
||||
for name in (
|
||||
'__or__', '__and__', '__xor__',
|
||||
'__ror__', '__rand__', '__rxor__',
|
||||
'__invert__'
|
||||
):
|
||||
if name not in classdict:
|
||||
enum_method = getattr(Flag, name)
|
||||
setattr(enum_class, name, enum_method)
|
||||
classdict[name] = enum_method
|
||||
#
|
||||
# replace any other __new__ with our own (as long as Enum is not None,
|
||||
# anyway) -- again, this is to support pickle
|
||||
if Enum is not None:
|
||||
|
@ -1467,19 +1479,34 @@ class Flag(Enum, boundary=STRICT):
|
|||
return bool(self._value_)
|
||||
|
||||
def __or__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
return self.__class__(self._value_ | other._value_)
|
||||
value = self._value_
|
||||
return self.__class__(value | other)
|
||||
|
||||
def __and__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
return self.__class__(self._value_ & other._value_)
|
||||
value = self._value_
|
||||
return self.__class__(value & other)
|
||||
|
||||
def __xor__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
return self.__class__(self._value_ ^ other._value_)
|
||||
value = self._value_
|
||||
return self.__class__(value ^ other)
|
||||
|
||||
def __invert__(self):
|
||||
if self._inverted_ is None:
|
||||
|
@ -1493,6 +1520,10 @@ class Flag(Enum, boundary=STRICT):
|
|||
self._inverted_._inverted_ = self
|
||||
return self._inverted_
|
||||
|
||||
__rand__ = __and__
|
||||
__ror__ = __or__
|
||||
__rxor__ = __xor__
|
||||
|
||||
|
||||
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
|
||||
"""
|
||||
|
@ -1500,42 +1531,6 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
|
|||
"""
|
||||
|
||||
|
||||
def __or__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif isinstance(other, int):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
value = self._value_
|
||||
return self.__class__(value | other)
|
||||
|
||||
def __and__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif isinstance(other, int):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
value = self._value_
|
||||
return self.__class__(value & other)
|
||||
|
||||
def __xor__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._value_
|
||||
elif isinstance(other, int):
|
||||
other = other
|
||||
else:
|
||||
return NotImplemented
|
||||
value = self._value_
|
||||
return self.__class__(value ^ other)
|
||||
|
||||
__ror__ = __or__
|
||||
__rand__ = __and__
|
||||
__rxor__ = __xor__
|
||||
__invert__ = Flag.__invert__
|
||||
|
||||
|
||||
def _high_bit(value):
|
||||
"""
|
||||
returns index of highest bit, or -1 if value is zero or negative
|
||||
|
@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
|||
body['_flag_mask_'] = None
|
||||
body['_all_bits_'] = None
|
||||
body['_inverted_'] = None
|
||||
body['__or__'] = Flag.__or__
|
||||
body['__xor__'] = Flag.__xor__
|
||||
body['__and__'] = Flag.__and__
|
||||
body['__ror__'] = Flag.__ror__
|
||||
body['__rxor__'] = Flag.__rxor__
|
||||
body['__rand__'] = Flag.__rand__
|
||||
body['__invert__'] = Flag.__invert__
|
||||
for name, obj in cls.__dict__.items():
|
||||
if name in ('__dict__', '__weakref__'):
|
||||
continue
|
||||
|
|
|
@ -2496,6 +2496,13 @@ class TestSpecial(unittest.TestCase):
|
|||
self.assertEqual(Some.x.value, 1)
|
||||
self.assertEqual(Some.y.value, 2)
|
||||
|
||||
def test_custom_flag_bitwise(self):
|
||||
class MyIntFlag(int, Flag):
|
||||
ONE = 1
|
||||
TWO = 2
|
||||
FOUR = 4
|
||||
self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
|
||||
self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
|
||||
|
||||
class TestOrder(unittest.TestCase):
|
||||
"test usage of the `_order_` attribute"
|
||||
|
|
Loading…
Reference in New Issue