issue23591: fix flag decomposition and repr

This commit is contained in:
Ethan Furman 2016-09-18 13:15:41 -07:00
parent 9a7bbb2e3f
commit 3515dcce80
3 changed files with 193 additions and 86 deletions

View File

@ -674,6 +674,8 @@ while combinations of flags won't::
... green = auto() ... green = auto()
... white = red | blue | green ... white = red | blue | green
... ...
>>> Color.white
<Color.white: 7>
Giving a name to the "no flags set" condition does not change its boolean Giving a name to the "no flags set" condition does not change its boolean
value:: value::
@ -1068,3 +1070,23 @@ but not of the class::
>>> dir(Planet.EARTH) >>> dir(Planet.EARTH)
['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value'] ['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value']
Combining members of ``Flag``
"""""""""""""""""""""""""""""
If a combination of Flag members is not named, the :func:`repr` will include
all named flags and all named combinations of flags that are in the value::
>>> class Color(Flag):
... red = auto()
... green = auto()
... blue = auto()
... magenta = red | blue
... yellow = red | green
... cyan = green | blue
...
>>> Color(3) # named combination
<Color.yellow: 3>
>>> Color(7) # not named combination
<Color.cyan|magenta|blue|yellow|green|red: 7>

View File

@ -1,7 +1,7 @@
import sys import sys
from types import MappingProxyType, DynamicClassAttribute from types import MappingProxyType, DynamicClassAttribute
from functools import reduce from functools import reduce
from operator import or_ as _or_ from operator import or_ as _or_, and_ as _and_, xor, neg
# try _collections first to reduce startup cost # try _collections first to reduce startup cost
try: try:
@ -47,11 +47,12 @@ def _make_class_unpicklable(cls):
cls.__reduce_ex__ = _break_on_call_reduce cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = '<unknown>' cls.__module__ = '<unknown>'
_auto_null = object()
class auto: class auto:
""" """
Instances are replaced with an appropriate value in Enum class suites. Instances are replaced with an appropriate value in Enum class suites.
""" """
pass value = _auto_null
class _EnumDict(dict): class _EnumDict(dict):
@ -77,7 +78,7 @@ class _EnumDict(dict):
""" """
if _is_sunder(key): if _is_sunder(key):
if key not in ( if key not in (
'_order_', '_create_pseudo_member_', '_decompose_', '_order_', '_create_pseudo_member_',
'_generate_next_value_', '_missing_', '_generate_next_value_', '_missing_',
): ):
raise ValueError('_names_ are reserved for future Enum use') raise ValueError('_names_ are reserved for future Enum use')
@ -94,7 +95,9 @@ class _EnumDict(dict):
# enum overwriting a descriptor? # enum overwriting a descriptor?
raise TypeError('%r already defined as: %r' % (key, self[key])) raise TypeError('%r already defined as: %r' % (key, self[key]))
if isinstance(value, auto): if isinstance(value, auto):
value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) if value.value == _auto_null:
value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:])
value = value.value
self._member_names.append(key) self._member_names.append(key)
self._last_values.append(value) self._last_values.append(value)
super().__setitem__(key, value) super().__setitem__(key, value)
@ -658,7 +661,7 @@ class Flag(Enum):
try: try:
high_bit = _high_bit(last_value) high_bit = _high_bit(last_value)
break break
except TypeError: except Exception:
raise TypeError('Invalid Flag value: %r' % last_value) from None raise TypeError('Invalid Flag value: %r' % last_value) from None
return 2 ** (high_bit+1) return 2 ** (high_bit+1)
@ -668,61 +671,38 @@ class Flag(Enum):
if value < 0: if value < 0:
value = ~value value = ~value
possible_member = cls._create_pseudo_member_(value) possible_member = cls._create_pseudo_member_(value)
for member in possible_member._decompose_():
if member._name_ is None and member._value_ != 0:
raise ValueError('%r is not a valid %s' % (original_value, cls.__name__))
if original_value < 0: if original_value < 0:
possible_member = ~possible_member possible_member = ~possible_member
return possible_member return possible_member
@classmethod @classmethod
def _create_pseudo_member_(cls, value): def _create_pseudo_member_(cls, value):
"""
Create a composite member iff value contains only members.
"""
pseudo_member = cls._value2member_map_.get(value, None) pseudo_member = cls._value2member_map_.get(value, None)
if pseudo_member is None: if pseudo_member is None:
# construct a non-singleton enum pseudo-member # verify all bits are accounted for
_, extra_flags = _decompose(cls, value)
if extra_flags:
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
# construct a singleton enum pseudo-member
pseudo_member = object.__new__(cls) pseudo_member = object.__new__(cls)
pseudo_member._name_ = None pseudo_member._name_ = None
pseudo_member._value_ = value pseudo_member._value_ = value
cls._value2member_map_[value] = pseudo_member cls._value2member_map_[value] = pseudo_member
return pseudo_member return pseudo_member
def _decompose_(self):
"""Extract all members from the value."""
value = self._value_
members = []
cls = self.__class__
for member in sorted(cls, key=lambda m: m._value_, reverse=True):
while _high_bit(value) > _high_bit(member._value_):
unknown = self._create_pseudo_member_(2 ** _high_bit(value))
members.append(unknown)
value &= ~unknown._value_
if (
(value & member._value_ == member._value_)
and (member._value_ or not members)
):
value &= ~member._value_
members.append(member)
if not members or value:
members.append(self._create_pseudo_member_(value))
members = list(members)
return members
def __contains__(self, other): def __contains__(self, other):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return other._value_ & self._value_ == other._value_ return other._value_ & self._value_ == other._value_
def __iter__(self):
if self.value == 0:
return iter([])
else:
return iter(self._decompose_())
def __repr__(self): def __repr__(self):
cls = self.__class__ cls = self.__class__
if self._name_ is not None: if self._name_ is not None:
return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_)
members = self._decompose_() members, uncovered = _decompose(cls, self._value_)
return '<%s.%s: %r>' % ( return '<%s.%s: %r>' % (
cls.__name__, cls.__name__,
'|'.join([str(m._name_ or m._value_) for m in members]), '|'.join([str(m._name_ or m._value_) for m in members]),
@ -733,7 +713,7 @@ class Flag(Enum):
cls = self.__class__ cls = self.__class__
if self._name_ is not None: if self._name_ is not None:
return '%s.%s' % (cls.__name__, self._name_) return '%s.%s' % (cls.__name__, self._name_)
members = self._decompose_() members, uncovered = _decompose(cls, self._value_)
if len(members) == 1 and members[0]._name_ is None: if len(members) == 1 and members[0]._name_ is None:
return '%s.%r' % (cls.__name__, members[0]._value_) return '%s.%r' % (cls.__name__, members[0]._value_)
else: else:
@ -761,8 +741,11 @@ class Flag(Enum):
return self.__class__(self._value_ ^ other._value_) return self.__class__(self._value_ ^ other._value_)
def __invert__(self): def __invert__(self):
members = self._decompose_() members, uncovered = _decompose(self.__class__, self._value_)
inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] inverted_members = [
m for m in self.__class__
if m not in members and not m._value_ & self._value_
]
inverted = reduce(_or_, inverted_members, self.__class__(0)) inverted = reduce(_or_, inverted_members, self.__class__(0))
return self.__class__(inverted) return self.__class__(inverted)
@ -770,26 +753,46 @@ class Flag(Enum):
class IntFlag(int, Flag): class IntFlag(int, Flag):
"""Support for integer-based Flags""" """Support for integer-based Flags"""
@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
new_member = cls._create_pseudo_member_(value)
return new_member
@classmethod @classmethod
def _create_pseudo_member_(cls, value): def _create_pseudo_member_(cls, value):
pseudo_member = cls._value2member_map_.get(value, None) pseudo_member = cls._value2member_map_.get(value, None)
if pseudo_member is None: if pseudo_member is None:
# construct a non-singleton enum pseudo-member need_to_create = [value]
pseudo_member = int.__new__(cls, value) # get unaccounted for bits
pseudo_member._name_ = None _, extra_flags = _decompose(cls, value)
pseudo_member._value_ = value # timer = 10
cls._value2member_map_[value] = pseudo_member while extra_flags:
# timer -= 1
bit = _high_bit(extra_flags)
flag_value = 2 ** bit
if (flag_value not in cls._value2member_map_ and
flag_value not in need_to_create
):
need_to_create.append(flag_value)
if extra_flags == -flag_value:
extra_flags = 0
else:
extra_flags ^= flag_value
for value in reversed(need_to_create):
# construct singleton pseudo-members
pseudo_member = int.__new__(cls, value)
pseudo_member._name_ = None
pseudo_member._value_ = value
cls._value2member_map_[value] = pseudo_member
return pseudo_member return pseudo_member
@classmethod
def _missing_(cls, value):
possible_member = cls._create_pseudo_member_(value)
return possible_member
def __or__(self, other): def __or__(self, other):
if not isinstance(other, (self.__class__, int)): if not isinstance(other, (self.__class__, int)):
return NotImplemented return NotImplemented
return self.__class__(self._value_ | self.__class__(other)._value_) result = self.__class__(self._value_ | self.__class__(other)._value_)
return result
def __and__(self, other): def __and__(self, other):
if not isinstance(other, (self.__class__, int)): if not isinstance(other, (self.__class__, int)):
@ -806,17 +809,13 @@ class IntFlag(int, Flag):
__rxor__ = __xor__ __rxor__ = __xor__
def __invert__(self): def __invert__(self):
# members = self._decompose_() result = self.__class__(~self._value_)
# inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] return result
# inverted = reduce(_or_, inverted_members, self.__class__(0))
return self.__class__(~self._value_)
def _high_bit(value): def _high_bit(value):
"""returns index of highest bit, or -1 if value is zero or negative""" """returns index of highest bit, or -1 if value is zero or negative"""
return value.bit_length() - 1 if value > 0 else -1 return value.bit_length() - 1
def unique(enumeration): def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values.""" """Class decorator for enumerations ensuring unique member values."""
@ -830,3 +829,40 @@ def unique(enumeration):
raise ValueError('duplicate values found in %r: %s' % raise ValueError('duplicate values found in %r: %s' %
(enumeration, alias_details)) (enumeration, alias_details))
return enumeration return enumeration
def _decompose(flag, value):
"""Extract all members from the value."""
# _decompose is only called if the value is not named
not_covered = value
negative = value < 0
if negative:
# only check for named flags
flags_to_check = [
(m, v)
for v, m in flag._value2member_map_.items()
if m.name is not None
]
else:
# check for named flags and powers-of-two flags
flags_to_check = [
(m, v)
for v, m in flag._value2member_map_.items()
if m.name is not None or _power_of_two(v)
]
members = []
for member, member_value in flags_to_check:
if member_value and member_value & value == member_value:
members.append(member)
not_covered &= ~member_value
if not members and value in flag._value2member_map_:
members.append(flag._value2member_map_[value])
members.sort(key=lambda m: m._value_, reverse=True)
if len(members) > 1 and members[0].value == value:
# we have the breakdown, don't need the value member itself
members.pop(0)
return members, not_covered
def _power_of_two(value):
if value < 1:
return False
return value == 2 ** _high_bit(value)

View File

@ -1634,6 +1634,13 @@ class TestEnum(unittest.TestCase):
self.assertEqual(Color.blue.value, 2) self.assertEqual(Color.blue.value, 2)
self.assertEqual(Color.green.value, 3) self.assertEqual(Color.green.value, 3)
def test_duplicate_auto(self):
class Dupes(Enum):
first = primero = auto()
second = auto()
third = auto()
self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes))
class TestOrder(unittest.TestCase): class TestOrder(unittest.TestCase):
@ -1731,7 +1738,7 @@ class TestFlag(unittest.TestCase):
self.assertEqual(str(Open.AC), 'Open.AC') self.assertEqual(str(Open.AC), 'Open.AC')
self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') self.assertEqual(str(Open.RO | Open.CE), 'Open.CE')
self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO')
self.assertEqual(str(~Open.RO), 'Open.CE|AC') self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO')
self.assertEqual(str(~Open.WO), 'Open.CE|RW') self.assertEqual(str(~Open.WO), 'Open.CE|RW')
self.assertEqual(str(~Open.AC), 'Open.CE') self.assertEqual(str(~Open.AC), 'Open.CE')
self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC') self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC')
@ -1758,7 +1765,7 @@ class TestFlag(unittest.TestCase):
self.assertEqual(repr(Open.AC), '<Open.AC: 3>') self.assertEqual(repr(Open.AC), '<Open.AC: 3>')
self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>') self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>')
self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>') self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>')
self.assertEqual(repr(~Open.RO), '<Open.CE|AC: 524291>') self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: 524291>')
self.assertEqual(repr(~Open.WO), '<Open.CE|RW: 524290>') self.assertEqual(repr(~Open.WO), '<Open.CE|RW: 524290>')
self.assertEqual(repr(~Open.AC), '<Open.CE: 524288>') self.assertEqual(repr(~Open.AC), '<Open.CE: 524288>')
self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC: 3>') self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC: 3>')
@ -1949,6 +1956,33 @@ class TestFlag(unittest.TestCase):
red = 'not an int' red = 'not an int'
blue = auto() blue = auto()
def test_cascading_failure(self):
class Bizarre(Flag):
c = 3
d = 4
f = 6
# Bizarre.c | Bizarre.d
self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5)
self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5)
self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2)
self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2)
self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1)
self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1)
def test_duplicate_auto(self):
class Dupes(Enum):
first = primero = auto()
second = auto()
third = auto()
self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes))
def test_bizarre(self):
class Bizarre(Flag):
b = 3
c = 4
d = 6
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
class TestIntFlag(unittest.TestCase): class TestIntFlag(unittest.TestCase):
"""Tests of the IntFlags.""" """Tests of the IntFlags."""
@ -1965,6 +1999,21 @@ class TestIntFlag(unittest.TestCase):
AC = 3 AC = 3
CE = 1<<19 CE = 1<<19
def test_type(self):
Perm = self.Perm
Open = self.Open
for f in Perm:
self.assertTrue(isinstance(f, Perm))
self.assertEqual(f, f.value)
self.assertTrue(isinstance(Perm.W | Perm.X, Perm))
self.assertEqual(Perm.W | Perm.X, 3)
for f in Open:
self.assertTrue(isinstance(f, Open))
self.assertEqual(f, f.value)
self.assertTrue(isinstance(Open.WO | Open.RW, Open))
self.assertEqual(Open.WO | Open.RW, 3)
def test_str(self): def test_str(self):
Perm = self.Perm Perm = self.Perm
self.assertEqual(str(Perm.R), 'Perm.R') self.assertEqual(str(Perm.R), 'Perm.R')
@ -1975,14 +2024,14 @@ class TestIntFlag(unittest.TestCase):
self.assertEqual(str(Perm.R | 8), 'Perm.8|R') self.assertEqual(str(Perm.R | 8), 'Perm.8|R')
self.assertEqual(str(Perm(0)), 'Perm.0') self.assertEqual(str(Perm(0)), 'Perm.0')
self.assertEqual(str(Perm(8)), 'Perm.8') self.assertEqual(str(Perm(8)), 'Perm.8')
self.assertEqual(str(~Perm.R), 'Perm.W|X|-8') self.assertEqual(str(~Perm.R), 'Perm.W|X')
self.assertEqual(str(~Perm.W), 'Perm.R|X|-8') self.assertEqual(str(~Perm.W), 'Perm.R|X')
self.assertEqual(str(~Perm.X), 'Perm.R|W|-8') self.assertEqual(str(~Perm.X), 'Perm.R|W')
self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X|-8') self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X')
self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8') self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8')
self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X|-16') self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X')
self.assertEqual(str(Perm(~0)), 'Perm.R|W|X|-8') self.assertEqual(str(Perm(~0)), 'Perm.R|W|X')
self.assertEqual(str(Perm(~8)), 'Perm.R|W|X|-16') self.assertEqual(str(Perm(~8)), 'Perm.R|W|X')
Open = self.Open Open = self.Open
self.assertEqual(str(Open.RO), 'Open.RO') self.assertEqual(str(Open.RO), 'Open.RO')
@ -1991,12 +2040,12 @@ class TestIntFlag(unittest.TestCase):
self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') self.assertEqual(str(Open.RO | Open.CE), 'Open.CE')
self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO')
self.assertEqual(str(Open(4)), 'Open.4') self.assertEqual(str(Open(4)), 'Open.4')
self.assertEqual(str(~Open.RO), 'Open.CE|AC|-524292') self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO')
self.assertEqual(str(~Open.WO), 'Open.CE|RW|-524292') self.assertEqual(str(~Open.WO), 'Open.CE|RW')
self.assertEqual(str(~Open.AC), 'Open.CE|-524292') self.assertEqual(str(~Open.AC), 'Open.CE')
self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|-524292') self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|RW|WO')
self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW|-524292') self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW')
self.assertEqual(str(Open(~4)), 'Open.CE|AC|-524296') self.assertEqual(str(Open(~4)), 'Open.CE|AC|RW|WO')
def test_repr(self): def test_repr(self):
Perm = self.Perm Perm = self.Perm
@ -2008,14 +2057,14 @@ class TestIntFlag(unittest.TestCase):
self.assertEqual(repr(Perm.R | 8), '<Perm.8|R: 12>') self.assertEqual(repr(Perm.R | 8), '<Perm.8|R: 12>')
self.assertEqual(repr(Perm(0)), '<Perm.0: 0>') self.assertEqual(repr(Perm(0)), '<Perm.0: 0>')
self.assertEqual(repr(Perm(8)), '<Perm.8: 8>') self.assertEqual(repr(Perm(8)), '<Perm.8: 8>')
self.assertEqual(repr(~Perm.R), '<Perm.W|X|-8: -5>') self.assertEqual(repr(~Perm.R), '<Perm.W|X: -5>')
self.assertEqual(repr(~Perm.W), '<Perm.R|X|-8: -3>') self.assertEqual(repr(~Perm.W), '<Perm.R|X: -3>')
self.assertEqual(repr(~Perm.X), '<Perm.R|W|-8: -2>') self.assertEqual(repr(~Perm.X), '<Perm.R|W: -2>')
self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X|-8: -7>') self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X: -7>')
self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '<Perm.-8: -8>') self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '<Perm.-8: -8>')
self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X|-16: -13>') self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X: -13>')
self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X|-8: -1>') self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X: -1>')
self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X|-16: -9>') self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X: -9>')
Open = self.Open Open = self.Open
self.assertEqual(repr(Open.RO), '<Open.RO: 0>') self.assertEqual(repr(Open.RO), '<Open.RO: 0>')
@ -2024,12 +2073,12 @@ class TestIntFlag(unittest.TestCase):
self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>') self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>')
self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>') self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>')
self.assertEqual(repr(Open(4)), '<Open.4: 4>') self.assertEqual(repr(Open(4)), '<Open.4: 4>')
self.assertEqual(repr(~Open.RO), '<Open.CE|AC|-524292: -1>') self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: -1>')
self.assertEqual(repr(~Open.WO), '<Open.CE|RW|-524292: -2>') self.assertEqual(repr(~Open.WO), '<Open.CE|RW: -2>')
self.assertEqual(repr(~Open.AC), '<Open.CE|-524292: -4>') self.assertEqual(repr(~Open.AC), '<Open.CE: -4>')
self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|-524292: -524289>') self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|RW|WO: -524289>')
self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW|-524292: -524290>') self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW: -524290>')
self.assertEqual(repr(Open(~4)), '<Open.CE|AC|-524296: -5>') self.assertEqual(repr(Open(~4)), '<Open.CE|AC|RW|WO: -5>')
def test_or(self): def test_or(self):
Perm = self.Perm Perm = self.Perm