Skip to content

bpo-38250: [Enum] single-bit flags are canonical #24215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5f522d3
[Enum] fix Flag iteration, repr(), and str()
ethanfurman Jan 13, 2021
f7f9e72
add boundary KEEP for Flags (default for _convert_)
ethanfurman Jan 14, 2021
d289ba3
update re.RegexFlag for new Flag implementation
ethanfurman Jan 14, 2021
72dbdd7
remove extra white space
ethanfurman Jan 14, 2021
210fae7
test that zero-valued members are empty
ethanfurman Jan 14, 2021
48bcd07
update tests to confirm CONFORM with negate
ethanfurman Jan 14, 2021
1fd7471
update aenum.rst; add doctest to test_enum
ethanfurman Jan 14, 2021
aa425e6
fix doc test
ethanfurman Jan 14, 2021
4786942
optimizations
ethanfurman Jan 14, 2021
45565b2
formatting
ethanfurman Jan 14, 2021
806c8c6
add news entry
ethanfurman Jan 14, 2021
668c9a9
fix formatting of news entry
ethanfurman Jan 15, 2021
9bd9e97
update iteration method and order
ethanfurman Jan 20, 2021
18bcbac
add John Belmonte
ethanfurman Jan 20, 2021
f1c4584
more bit-fiddling improvements
ethanfurman Jan 20, 2021
e3713aa
use pop() instead of "del"
ethanfurman Jan 20, 2021
c4ec211
update DynamicClassAttribute __doc__
ethanfurman Jan 20, 2021
00b2bfe
remove formatting changes
ethanfurman Jan 20, 2021
9f432c3
remove extra parens
ethanfurman Jan 20, 2021
15c060a
remove formatting changes
ethanfurman Jan 20, 2021
86d7669
remove formatting
ethanfurman Jan 20, 2021
55915df
simplify determination of member iteration
ethanfurman Jan 22, 2021
95bf9c8
add note about next auto() value for Enum and Flag
ethanfurman Jan 25, 2021
41ac1ce
local name optimizations
ethanfurman Jan 25, 2021
3ea814e
remove commented-out code
ethanfurman Jan 25, 2021
4983558
add test for next auto() and _order_
ethanfurman Jan 25, 2021
651da18
raise TypeError if _value_ not added in custom new
ethanfurman Jan 25, 2021
6e99d48
enable doc tests, update formatting
ethanfurman Jan 25, 2021
b52c5a2
fix note
ethanfurman Jan 25, 2021
8d7b272
Update 2021-01-14-15-07-16.bpo-38250.1fvhOk.rst
ethanfurman Jan 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add boundary KEEP for Flags (default for _convert_)
some flag sets, such as ``ssl.Options`` are incomplete/inconsistent;
using KEEP allows those flags to exist, and have useful repr()s, etc.

also, add ``_inverted_`` attribute to Flag members to significantly
speed up that operation.
  • Loading branch information
ethanfurman committed Jan 14, 2021
commit f7f9e728e123dcaecccda0a96d4cdff6cce5bd01
107 changes: 77 additions & 30 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag',
'auto', 'unique',
'property',
'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT',
'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP',
]


Expand Down Expand Up @@ -65,19 +65,22 @@ def _is_private(cls_name, name):
return False

def _bits(num):
# return str(num) if num<=1 else bin(num>>1) + str(num&1)
if num in (0, 1):
return str(num)
if num == 0:
return '0b0'
negative = False
if num < 0:
negative = True
num = ~num
result = _bits(num>>1) + str(num&1)
digits = []
while num:
digits.insert(0, num&1)
num >>= 1
if negative:
result = '1' + ''.join(['10'[d=='1'] for d in result])
result = '0b1' + (''.join(['10'[d] for d in digits]).lstrip('0'))
else:
result = '0b0' + ''.join(str(d) for d in digits)
return result


def _bit_count(num):
"""
return number of set bits
Expand Down Expand Up @@ -130,6 +133,8 @@ def _is_single_bit(num):
"""
True if only one bit set in num (should be an int)
"""
if num == 0:
return False
num &= num - 1
return num == 0

Expand Down Expand Up @@ -442,6 +447,7 @@ def __new__(metacls, cls, bases, classdict, boundary=None, **kwds):
classdict['_boundary_'] = boundary or getattr(first_enum, '_boundary_', None)
classdict['_flag_mask_'] = flag_mask
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
classdict['_inverted_'] = None
#
# If a custom type is mixed into the Enum, and it does not know how
# to pickle itself, pickle.dumps will succeed but pickle.loads will
Expand Down Expand Up @@ -507,6 +513,7 @@ def __new__(metacls, cls, bases, classdict, boundary=None, **kwds):
delattr(enum_class, '_boundary_')
delattr(enum_class, '_flag_mask_')
delattr(enum_class, '_all_bits_')
delattr(enum_class, '_inverted_')
elif Flag is not None and issubclass(enum_class, Flag):
# ensure _all_bits_ is correct and there are no missing flags
single_bit_total = 0
Expand All @@ -524,7 +531,7 @@ def __new__(metacls, cls, bases, classdict, boundary=None, **kwds):
all_bit_total |= i
if i & multi_bit_total and not i & single_bit_total:
missed.append(i)
if missed:
if missed and enum_class._boundary_ is not KEEP:
raise TypeError('invalid Flag %r -- missing values: %s' % (cls, ', '.join((str(i) for i in missed))))
enum_class._flag_mask_ = single_bit_total
#
Expand All @@ -536,7 +543,12 @@ def __bool__(self):
"""
return True

def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1):
def __call__(
cls, value, names=None,
*,
module=None, qualname=None, type=None,
start=1, boundary=None,
):
"""
Either returns an existing member, or creates a new enum class.

Expand Down Expand Up @@ -571,6 +583,7 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s
qualname=qualname,
type=type,
start=start,
boundary=boundary,
)

def __contains__(cls, member):
Expand Down Expand Up @@ -653,7 +666,12 @@ def __setattr__(cls, name, value):
raise AttributeError('Cannot reassign members.')
super().__setattr__(name, value)

def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1):
def _create_(
cls, class_name, names,
*,
module=None, qualname=None, type=None,
start=1, boundary=None,
):
"""
Convenience method to create a new Enum class.

Expand Down Expand Up @@ -703,9 +721,12 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s
if qualname is not None:
classdict['__qualname__'] = qualname

return metacls.__new__(metacls, class_name, bases, classdict)
return metacls.__new__(
metacls, class_name, bases, classdict,
boundary=boundary,
)

def _convert_(cls, name, module, filter, source=None):
def _convert_(cls, name, module, filter, source=None, boundary=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
Expand All @@ -732,7 +753,7 @@ def _convert_(cls, name, module, filter, source=None):
except TypeError:
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = cls(name, members, module=module)
cls = cls(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_name
module_globals.update(cls.__members__)
module_globals[name] = cls
Expand Down Expand Up @@ -847,6 +868,7 @@ class Enum(metaclass=EnumMeta):

Derive from this class to define new enumerations.
"""

def __new__(cls, value):
# all enum instances are actually created during class construction
# without calling this method; this method is called by the metaclass'
Expand Down Expand Up @@ -1026,11 +1048,13 @@ class FlagBoundary(StrEnum):
"strict" -> error is raised [default for Flag]
"conform" -> extra bits are discarded
"eject" -> lose flag status [default for IntFlag]
"keep" -> keep flag status and all bits
"""
STRICT = auto()
CONFORM = auto()
EJECT = auto()
STRICT, CONFORM, EJECT = FlagBoundary
KEEP = auto()
STRICT, CONFORM, EJECT, KEEP = FlagBoundary


class Flag(Enum, boundary=STRICT):
Expand Down Expand Up @@ -1088,15 +1112,21 @@ def _missing_(cls, value):
value = value & cls._flag_mask_
elif cls._boundary_ is EJECT:
return value
elif cls._boundary_ is KEEP:
if value < 0:
value = max(cls._all_bits_+1, 2**(value.bit_length())) + value
else:
raise ValueError('unknown flag boundary: %r' % (cls._boundary_, ))
elif value < 0:
if value < 0:
neg_value = value
value = cls._all_bits_ + 1 + value
# get members
members, _ = _decompose(cls, value)
if _:
raise ValueError('%s: _decompose(%r) --> %r, %r' % (cls.__name__, value, members, _))
members, unknown = _decompose(cls, value)
if unknown and cls._boundary_ is not KEEP:
raise ValueError(
'%s: _decompose(%r) --> %r, %r'
% (cls.__name__, value, members, unknown)
)
# normal Flag?
__new__ = getattr(cls, '__new_member__', None)
if cls._member_type_ is object and not __new__:
Expand All @@ -1106,12 +1136,19 @@ def _missing_(cls, value):
pseudo_member = (__new__ or cls._member_type_.__new__)(cls, value)
if not hasattr(pseudo_member, 'value'):
pseudo_member._value_ = value
pseudo_member._name_ = '|'.join([m._name_ for m in members]) or None
if members:
pseudo_member._name_ = '|'.join([m._name_ for m in members])
if unknown:
pseudo_member._name_ += '|0x%x' % unknown
else:
pseudo_member._name_ = None
# use setdefault in case another thread already created a composite
# with this value
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
if neg_value is not None:
cls._value2member_map_[neg_value] = pseudo_member
# with this value, but only if all members are known
# note: zero is a special case -- add it
if not unknown:
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
if neg_value is not None:
cls._value2member_map_[neg_value] = pseudo_member
return pseudo_member

def __contains__(self, other):
Expand Down Expand Up @@ -1146,7 +1183,7 @@ def __str__(self):
if self._name_ is not None:
return '%s.%s' % (cls.__name__, self._name_)
else:
return '%s.%s' % (cls.__name__, self._value_)
return '%s(%s)' % (cls.__name__, self._value_)

def __bool__(self):
return bool(self._value_)
Expand All @@ -1167,12 +1204,19 @@ def __xor__(self, other):
return self.__class__(self._value_ ^ other._value_)

def __invert__(self):
current = set(list(self))
return self.__class__(reduce(
_or_,
[m._value_ for m in self.__class__ if m not in current],
0,
))
if self._inverted_ is None:
if self._boundary_ is KEEP:
# use all bits
self._inverted_ = self.__class__(~self._value_)
else:
# get flags not in this member
self._inverted_ = self.__class__(reduce(
_or_,
[m._value_ for m in self.__class__ if m not in self],
0
))
self._inverted_._inverted_ = self
return self._inverted_


class IntFlag(int, Flag, boundary=EJECT):
Expand Down Expand Up @@ -1255,6 +1299,9 @@ def _decompose(flag, value):
]
possibles.sort(key=lambda m: m._value_, reverse=True)
for multi in possibles:
if multi._value_ == 0:
# do not add the zero flag
continue
if multi._value_ & value == multi._value_:
members.append(multi)
value ^= multi._value_
Expand Down
21 changes: 16 additions & 5 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,12 +2279,12 @@ def test_str(self):
self.assertEqual(str(Perm.X), 'Perm.X')
self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W')
self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X')
self.assertEqual(str(Perm(0)), 'Perm.0')
self.assertEqual(str(Perm(0)), 'Perm(0)')
self.assertEqual(str(~Perm.R), 'Perm.W|X')
self.assertEqual(str(~Perm.W), 'Perm.R|X')
self.assertEqual(str(~Perm.X), 'Perm.R|W')
self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X')
self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.0')
self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm(0)')
self.assertEqual(str(Perm(~0)), 'Perm.R|W|X')

Open = self.Open
Expand Down Expand Up @@ -2421,6 +2421,11 @@ class Space(Flag, boundary=EJECT):
self.assertEqual(Space(11), 11)
self.assertTrue(type(Space(11)) is int)

def test_iter(self):
Color = self.Color
Open = self.Open
self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE])
self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE])

def test_programatic_function_string(self):
Perm = Flag('Perm', 'R W X')
Expand Down Expand Up @@ -2762,13 +2767,13 @@ def test_str(self):
self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W')
self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X')
self.assertEqual(str(Perm.R | 8), '12')
self.assertEqual(str(Perm(0)), 'Perm.0')
self.assertEqual(str(Perm(0)), 'Perm(0)')
self.assertEqual(str(Perm(8)), '8')
self.assertEqual(str(~Perm.R), 'Perm.W|X')
self.assertEqual(str(~Perm.W), 'Perm.R|X')
self.assertEqual(str(~Perm.X), 'Perm.R|W')
self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X')
self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.0')
self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm(0)')
self.assertEqual(str(~(Perm.R | 8)), '-13')
self.assertEqual(str(Perm(~0)), 'Perm.R|W|X')
self.assertEqual(str(Perm(~8)), '-9')
Expand Down Expand Up @@ -2938,6 +2943,12 @@ class Space(Flag, boundary=EJECT):
self.assertEqual(Space(11), 11)
self.assertTrue(type(Space(11)) is int)

def test_iter(self):
Color = self.Color
Open = self.Open
self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE])
self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE])

def test_programatic_function_string(self):
Perm = IntFlag('Perm', 'R W X')
lst = list(Perm)
Expand Down Expand Up @@ -3256,7 +3267,7 @@ class Sillier(IntEnum):
Help on class Color in module %s:

class Color(enum.Enum)
| Color(value, names=None, *, module=None, qualname=None, type=None, start=1)
| Color(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
|\x20\x20
| An enumeration.
|\x20\x20
Expand Down