|
6 | 6 | __all__ = [
|
7 | 7 | 'EnumType', 'EnumMeta',
|
8 | 8 | 'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag',
|
9 |
| - 'auto', 'unique', |
10 |
| - 'property', |
| 9 | + 'auto', 'unique', 'property', 'verify', |
11 | 10 | 'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP',
|
12 | 11 | 'global_flag_repr', 'global_enum_repr', 'global_enum',
|
| 12 | + 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE', |
13 | 13 | ]
|
14 | 14 |
|
15 | 15 |
|
@@ -89,6 +89,9 @@ def _break_on_call_reduce(self, proto):
|
89 | 89 | setattr(obj, '__module__', '<unknown>')
|
90 | 90 |
|
91 | 91 | def _iter_bits_lsb(num):
|
| 92 | + # num must be an integer |
| 93 | + if isinstance(num, Enum): |
| 94 | + num = num.value |
92 | 95 | while num:
|
93 | 96 | b = num & (~num + 1)
|
94 | 97 | yield b
|
@@ -538,13 +541,6 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
|
538 | 541 | else:
|
539 | 542 | # multi-bit flags are considered aliases
|
540 | 543 | multi_bit_total |= flag_value
|
541 |
| - if enum_class._boundary_ is not KEEP: |
542 |
| - missed = list(_iter_bits_lsb(multi_bit_total & ~single_bit_total)) |
543 |
| - if missed: |
544 |
| - raise TypeError( |
545 |
| - 'invalid Flag %r -- missing values: %s' |
546 |
| - % (cls, ', '.join((str(i) for i in missed))) |
547 |
| - ) |
548 | 544 | enum_class._flag_mask_ = single_bit_total
|
549 | 545 | #
|
550 | 546 | # set correct __iter__
|
@@ -688,7 +684,10 @@ def __members__(cls):
|
688 | 684 | return MappingProxyType(cls._member_map_)
|
689 | 685 |
|
690 | 686 | def __repr__(cls):
|
691 |
| - return "<enum %r>" % cls.__name__ |
| 687 | + if Flag is not None and issubclass(cls, Flag): |
| 688 | + return "<flag %r>" % cls.__name__ |
| 689 | + else: |
| 690 | + return "<enum %r>" % cls.__name__ |
692 | 691 |
|
693 | 692 | def __reversed__(cls):
|
694 | 693 | """
|
@@ -1303,7 +1302,8 @@ def __invert__(self):
|
1303 | 1302 | else:
|
1304 | 1303 | # calculate flags not in this member
|
1305 | 1304 | self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
|
1306 |
| - self._inverted_._inverted_ = self |
| 1305 | + if isinstance(self._inverted_, self.__class__): |
| 1306 | + self._inverted_._inverted_ = self |
1307 | 1307 | return self._inverted_
|
1308 | 1308 |
|
1309 | 1309 |
|
@@ -1561,6 +1561,91 @@ def convert_class(cls):
|
1561 | 1561 | return enum_class
|
1562 | 1562 | return convert_class
|
1563 | 1563 |
|
| 1564 | +@_simple_enum(StrEnum) |
| 1565 | +class EnumCheck: |
| 1566 | + """ |
| 1567 | + various conditions to check an enumeration for |
| 1568 | + """ |
| 1569 | + CONTINUOUS = "no skipped integer values" |
| 1570 | + NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags" |
| 1571 | + UNIQUE = "one name per value" |
| 1572 | +CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck |
| 1573 | + |
| 1574 | + |
| 1575 | +class verify: |
| 1576 | + """ |
| 1577 | + Check an enumeration for various constraints. (see EnumCheck) |
| 1578 | + """ |
| 1579 | + def __init__(self, *checks): |
| 1580 | + self.checks = checks |
| 1581 | + def __call__(self, enumeration): |
| 1582 | + checks = self.checks |
| 1583 | + cls_name = enumeration.__name__ |
| 1584 | + if Flag is not None and issubclass(enumeration, Flag): |
| 1585 | + enum_type = 'flag' |
| 1586 | + elif issubclass(enumeration, Enum): |
| 1587 | + enum_type = 'enum' |
| 1588 | + else: |
| 1589 | + raise TypeError("the 'verify' decorator only works with Enum and Flag") |
| 1590 | + for check in checks: |
| 1591 | + if check is UNIQUE: |
| 1592 | + # check for duplicate names |
| 1593 | + duplicates = [] |
| 1594 | + for name, member in enumeration.__members__.items(): |
| 1595 | + if name != member.name: |
| 1596 | + duplicates.append((name, member.name)) |
| 1597 | + if duplicates: |
| 1598 | + alias_details = ', '.join( |
| 1599 | + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) |
| 1600 | + raise ValueError('aliases found in %r: %s' % |
| 1601 | + (enumeration, alias_details)) |
| 1602 | + elif check is CONTINUOUS: |
| 1603 | + values = set(e.value for e in enumeration) |
| 1604 | + if len(values) < 2: |
| 1605 | + continue |
| 1606 | + low, high = min(values), max(values) |
| 1607 | + missing = [] |
| 1608 | + if enum_type == 'flag': |
| 1609 | + # check for powers of two |
| 1610 | + for i in range(_high_bit(low)+1, _high_bit(high)): |
| 1611 | + if 2**i not in values: |
| 1612 | + missing.append(2**i) |
| 1613 | + elif enum_type == 'enum': |
| 1614 | + # check for powers of one |
| 1615 | + for i in range(low+1, high): |
| 1616 | + if i not in values: |
| 1617 | + missing.append(i) |
| 1618 | + else: |
| 1619 | + raise Exception('verify: unknown type %r' % enum_type) |
| 1620 | + if missing: |
| 1621 | + raise ValueError('invalid %s %r: missing values %s' % ( |
| 1622 | + enum_type, cls_name, ', '.join((str(m) for m in missing))) |
| 1623 | + ) |
| 1624 | + elif check is NAMED_FLAGS: |
| 1625 | + # examine each alias and check for unnamed flags |
| 1626 | + member_names = enumeration._member_names_ |
| 1627 | + member_values = [m.value for m in enumeration] |
| 1628 | + missing = [] |
| 1629 | + for name, alias in enumeration._member_map_.items(): |
| 1630 | + if name in member_names: |
| 1631 | + # not an alias |
| 1632 | + continue |
| 1633 | + values = list(_iter_bits_lsb(alias.value)) |
| 1634 | + missed = [v for v in values if v not in member_values] |
| 1635 | + if missed: |
| 1636 | + plural = ('', 's')[len(missed) > 1] |
| 1637 | + a = ('a ', '')[len(missed) > 1] |
| 1638 | + missing.append('%r is missing %snamed flag%s for value%s %s' % ( |
| 1639 | + name, a, plural, plural, |
| 1640 | + ', '.join(str(v) for v in missed) |
| 1641 | + )) |
| 1642 | + if missing: |
| 1643 | + raise ValueError( |
| 1644 | + 'invalid Flag %r: %s' |
| 1645 | + % (cls_name, '; '.join(missing)) |
| 1646 | + ) |
| 1647 | + return enumeration |
| 1648 | + |
1564 | 1649 | def _test_simple_enum(checked_enum, simple_enum):
|
1565 | 1650 | """
|
1566 | 1651 | A function that can be used to test an enum created with :func:`_simple_enum`
|
|
0 commit comments