diff --git a/Lib/enum.py b/Lib/enum.py new file mode 100644 --- /dev/null +++ b/Lib/enum.py @@ -0,0 +1,454 @@ +"""Python Enumerations""" + +import sys +import weakref +from collections import OrderedDict +from types import MappingProxyType + +__all__ = ['Enum', 'IntEnum'] + + +class _RouteClassAttributeToGetattr: + """Route attribute access on a class to __getattr__. + + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. + + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_dunder(name): + """Returns True if a dunder name, False otherwise.""" + return name[:2] == name[-2:] == '__' + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce__ = _break_on_call_reduce + cls.__module__ = '' + + +class _EnumDict(dict): + """Keeps track of definition order of the enum items. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super().__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or that doesn't have __get__. + + If a descriptor is added with the same name as an enum member, the name + is removed from _member_names (this may leave a hole in the numerical + sequence of values). + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + """ + if _is_dunder(key) or hasattr(value, '__get__'): + if key in self._member_names: + # overwriting an enum with a method? then remove the name from + # _member_names or it will become an enum anyway when the class + # is created + self._member_names.remove(key) + else: + if key in self._member_names: + raise TypeError('Attempted to reuse key: %r' % key) + self._member_names.append(key) + super().__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicity checks for it, but of course until +# EnumMeta finishes running the first time the Enum class doesn't exist. This +# is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + return _EnumDict() + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). + member_type, first_enum = metacls._get_mixins(bases) + __new__, save_new, use_args = metacls._find_new(classdict, member_type, + first_enum) + + # save enum items into separate mapping so they don't get baked into + # the new class + members = {k: classdict[k] for k in classdict._member_names} + for name in classdict._member_names: + del classdict[name] + + # check for illegal enum names (any others?) + invalid_names = set(members) & {'mro', '_create', + '_get_mixins', '_find_new'} + if invalid_names: + raise ValueError('Invalid enum member name: {0}'.format( + ','.join(invalid_names))) + + # create our new Enum type + enum_class = super().__new__(metacls, cls, bases, classdict) + enum_class._member_names = [] # names in definition order + enum_class._member_map = OrderedDict() # name->value map + + # Reverse value->name map for hashable values. + enum_class._value2member_map = {} + + # check for a __getnewargs__, and if not present sabotage + # pickling, since it won't work anyway + if (member_type is not object and + member_type.__dict__.get('__getnewargs__') is None + ): + _make_class_unpicklable(enum_class) + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + for member_name in classdict._member_names: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args: + enum_member = __new__(enum_class) + enum_member._value = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value'): + enum_member._value = member_type(*args) + enum_member._member_type = member_type + enum_member._name = member_name + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map.items(): + if canonical_member.value == enum_member._value: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names.append(member_name) + enum_class._member_map[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map[value] = enum_member + except TypeError: + pass + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__getnewargs__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if obj_method is not None and obj_method is class_method: + setattr(enum_class, name, enum_method) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + enum_class.__new_member__ = __new__ + enum_class.__new__ = Enum.__new__ + return enum_class + + def __call__(cls, value, names=None, *, module=None, type=None): + """Either returns an existing member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create(value, names, module=module, type=type) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map + + def __dir__(self): + return ['__class__', '__doc__', '__members__'] + self._member_names + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a read-only view of the internal mapping. + + """ + return MappingProxyType(cls._member_map) + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map[name] + except KeyError: + raise AttributeError(name) from None + + def __getitem__(cls, name): + return cls._member_map[name] + + def __iter__(cls): + return (cls._member_map[name] for name in cls._member_names) + + def __len__(cls): + return len(cls._member_names) + + def __repr__(cls): + return "" % cls.__name__ + + def _create(cls, class_name, names=None, *, module=None, type=None): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + metacls = cls.__class__ + bases = (cls, ) if type is None else (type, cls) + classdict = metacls.__prepare__(class_name, bases) + + # special processing needed for names? + if isinstance(names, str): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], str): + names = [(e, i) for (i, e) in enumerate(names, 1)] + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, str): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError) as exc: + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins(bases): + """Returns the type for creating enum members, and the first inherited + enum class. + + bases: the tuple of bases that was given to __new__ + + """ + if not bases: + return object, Enum + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (, , + # , , + # ) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + @staticmethod + def _find_new(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __new_member__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __new_member__ later? + save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __new_member__ before falling back to + # __new__ + for method in ('__new_member__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in { + None, + None.__new__, + object.__new__, + Enum.__new__, + }: + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +class Enum(metaclass=EnumMeta): + """Generic enumeration. + + Derive from this class to define new enumerations. + + """ + def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + if value in cls._value2member_map: + return cls._value2member_map[value] + # not there, now do long search -- O(n) behavior + for member in cls._member_map.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) + + def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name, self._value) + + def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name) + + def __dir__(self): + return (['__class__', '__doc__', 'name', 'value']) + + def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented + + def __getnewargs__(self): + return (self._value, ) + + def __hash__(self): + return hash(self._name) + + # _RouteClassAttributeToGetattr is used to provide access to the `name` + # and `value` properties of enum members while keeping some measure of + # protection from modification, while still allowing for an enumeration + # to have members named `name` and `value`. This works because enumeration + # members are not set directely on the enum class -- __getattr__ is + # used to look them up. + + @_RouteClassAttributeToGetattr + def name(self): + return self._name + + @_RouteClassAttributeToGetattr + def value(self): + return self._value + + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py new file mode 100644 --- /dev/null +++ b/Lib/test/test_enum.py @@ -0,0 +1,778 @@ +import enum +import sys +import unittest +from collections import OrderedDict +from pickle import dumps, loads, PicklingError +from enum import Enum, IntEnum + + +# for pickle tests +try: + class Stooges(Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception as exc: + Stooges = exc + +try: + class IntStooges(int, Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception as exc: + IntStooges = exc + +try: + class FloatStooges(float, Enum): + LARRY = 1.39 + CURLY = 2.72 + MOE = 3.142596 +except Exception as exc: + FloatStooges = exc + +# for pickle test and subclass tests +try: + class StrEnum(str, Enum): + 'accepts only string values' + class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' +except Exception as exc: + Name = exc + +try: + Question = Enum('Question', 'who what when where why', module=__name__) +except Exception as exc: + Question = exc + +try: + Answer = Enum('Answer', 'him this then there because') +except Exception as exc: + Answer = exc + +class TestEnum(unittest.TestCase): + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + + def test_enum_in_enum_out(self): + Season = self.Season + self.assertIs(Season(Season.WINTER), Season.WINTER) + + def test_enum_value(self): + Season = self.Season + self.assertEqual(Season.SPRING.value, 1) + + def test_intenum_value(self): + self.assertEqual(IntStooges.CURLY.value, 2) + + def test_dir_on_class(self): + Season = self.Season + self.assertEqual( + set(dir(Season)), + set(['__class__', '__doc__', '__members__', + 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + ) + + def test_dir_on_item(self): + Season = self.Season + self.assertEqual( + set(dir(Season.WINTER)), + set(['__class__', '__doc__', 'name', 'value']), + ) + + def test_enum(self): + Season = self.Season + lst = list(Season) + self.assertEqual(len(lst), len(Season)) + self.assertEqual(len(Season), 4, Season) + self.assertEqual( + [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) + + for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split(), 1): + e = Season(i) + self.assertEqual(e, getattr(Season, season)) + self.assertEqual(e.value, i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, season) + self.assertIn(e, Season) + self.assertIs(type(e), Season) + self.assertIsInstance(e, Season) + self.assertEqual(str(e), 'Season.' + season) + self.assertEqual( + repr(e), + ''.format(season, i), + ) + + def test_value_name(self): + Season = self.Season + self.assertEqual(Season.SPRING.name, 'SPRING') + self.assertEqual(Season.SPRING.value, 1) + with self.assertRaises(AttributeError): + Season.SPRING.name = 'invierno' + with self.assertRaises(AttributeError): + Season.SPRING.value = 2 + + def test_invalid_names(self): + with self.assertRaises(ValueError): + class Wrong(Enum): + mro = 9 + with self.assertRaises(ValueError): + class Wrong(Enum): + _create= 11 + with self.assertRaises(ValueError): + class Wrong(Enum): + _get_mixins = 9 + with self.assertRaises(ValueError): + class Wrong(Enum): + _find_new = 1 + + + def test_contains(self): + Season = self.Season + self.assertIn(Season.AUTUMN, Season) + self.assertNotIn(3, Season) + + val = Season(3) + self.assertIn(val, Season) + + class OtherEnum(Enum): + one = 1; two = 2 + self.assertNotIn(OtherEnum.two, Season) + + def test_comparisons(self): + Season = self.Season + with self.assertRaises(TypeError): + Season.SPRING < Season.WINTER + with self.assertRaises(TypeError): + Season.SPRING > 4 + + self.assertNotEqual(Season.SPRING, 1) + + class Part(Enum): + SPRING = 1 + CLIP = 2 + BARREL = 3 + + self.assertNotEqual(Season.SPRING, Part.SPRING) + with self.assertRaises(TypeError): + Season.SPRING < Part.CLIP + + def test_enum_duplicates(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = FALL = 3 + WINTER = 4 + ANOTHER_SPRING = 1 + lst = list(Season) + self.assertEqual( + lst, + [Season.SPRING, Season.SUMMER, + Season.AUTUMN, Season.WINTER, + ]) + self.assertIs(Season.FALL, Season.AUTUMN) + self.assertEqual(Season.FALL.value, 3) + self.assertEqual(Season.AUTUMN.value, 3) + self.assertIs(Season(3), Season.AUTUMN) + self.assertIs(Season(1), Season.SPRING) + self.assertEqual(Season.FALL.name, 'AUTUMN') + self.assertEqual( + [k for k,v in Season.__members__.items() if v.name != k], + ['FALL', 'ANOTHER_SPRING'], + ) + + def test_enum_with_value_name(self): + class Huh(Enum): + name = 1 + value = 2 + self.assertEqual( + list(Huh), + [Huh.name, Huh.value], + ) + self.assertIs(type(Huh.name), Huh) + self.assertEqual(Huh.name.name, 'name') + self.assertEqual(Huh.name.value, 1) + + def test_intenum(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') + self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) + + lst = list(WeekDay) + self.assertEqual(len(lst), len(WeekDay)) + self.assertEqual(len(WeekDay), 7) + target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + target = target.split() + for i, weekday in enumerate(target, 1): + e = WeekDay(i) + self.assertEqual(e, i) + self.assertEqual(int(e), i) + self.assertEqual(e.name, weekday) + self.assertIn(e, WeekDay) + self.assertEqual(lst.index(e)+1, i) + self.assertTrue(0 < e < 8) + self.assertIs(type(e), WeekDay) + self.assertIsInstance(e, int) + self.assertIsInstance(e, Enum) + + def test_intenum_duplicates(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = TEUSDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + self.assertIs(WeekDay.TEUSDAY, WeekDay.TUESDAY) + self.assertEqual(WeekDay(3).name, 'TUESDAY') + self.assertEqual([k for k,v in WeekDay.__members__.items() + if v.name != k], ['TEUSDAY', ]) + + def test_pickle_enum(self): + if isinstance(Stooges, Exception): + raise Stooges + self.assertIs(Stooges.CURLY, loads(dumps(Stooges.CURLY))) + self.assertIs(Stooges, loads(dumps(Stooges))) + + def test_pickle_int(self): + if isinstance(IntStooges, Exception): + raise IntStooges + self.assertIs(IntStooges.CURLY, loads(dumps(IntStooges.CURLY))) + self.assertIs(IntStooges, loads(dumps(IntStooges))) + + def test_pickle_float(self): + if isinstance(FloatStooges, Exception): + raise FloatStooges + self.assertIs(FloatStooges.CURLY, loads(dumps(FloatStooges.CURLY))) + self.assertIs(FloatStooges, loads(dumps(FloatStooges))) + + def test_pickle_enum_function(self): + if isinstance(Answer, Exception): + raise Answer + self.assertIs(Answer.him, loads(dumps(Answer.him))) + self.assertIs(Answer, loads(dumps(Answer))) + + def test_pickle_enum_function_with_module(self): + if isinstance(Question, Exception): + raise Question + self.assertIs(Question.who, loads(dumps(Question.who))) + self.assertIs(Question, loads(dumps(Question))) + + def test_exploding_pickle(self): + BadPickle = Enum('BadPickle', 'dill sweet bread-n-butter') + enum._make_class_unpicklable(BadPickle) + globals()['BadPickle'] = BadPickle + with self.assertRaises(TypeError): + dumps(BadPickle.dill) + with self.assertRaises(PicklingError): + dumps(BadPickle) + + def test_string_enum(self): + class SkillLevel(str, Enum): + master = 'what is the sound of one hand clapping?' + journeyman = 'why did the chicken cross the road?' + apprentice = 'knock, knock!' + self.assertEqual(SkillLevel.apprentice, 'knock, knock!') + + def test_getattr_getitem(self): + class Period(Enum): + morning = 1 + noon = 2 + evening = 3 + night = 4 + self.assertIs(Period(2), Period.noon) + self.assertIs(getattr(Period, 'night'), Period.night) + self.assertIs(Period['morning'], Period.morning) + + def test_getattr_dunder(self): + Season = self.Season + self.assertTrue(getattr(Season, '__eq__')) + + def test_iteration_order(self): + class Season(Enum): + SUMMER = 2 + WINTER = 4 + AUTUMN = 3 + SPRING = 1 + self.assertEqual( + list(Season), + [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], + ) + + def test_programatic_function_string(self): + SummerMonth = Enum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_string_list(self): + SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + (('june', 1), ('july', 2), ('august', 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_from_dict(self): + SummerMonth = Enum( + 'SummerMonth', + OrderedDict((('june', 1), ('july', 2), ('august', 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_type(self): + SummerMonth = Enum('SummerMonth', 'june july august', type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_programatic_function_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + + def test_subclassing(self): + if isinstance(Name, Exception): + raise Name + self.assertEqual(Name.BDFL, 'Guido van Rossum') + self.assertTrue(Name.BDFL, Name('Guido van Rossum')) + self.assertIs(Name.BDFL, getattr(Name, 'BDFL')) + self.assertIs(Name.BDFL, loads(dumps(Name.BDFL))) + + def test_extending(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(TypeError): + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + + def test_exclude_methods(self): + class whatever(Enum): + this = 'that' + these = 'those' + def really(self): + return 'no, not %s' % self.value + self.assertIsNot(type(whatever.really), whatever) + self.assertEqual(whatever.this.really(), 'no, not that') + + def test_overwrite_enums(self): + class Why(Enum): + question = 1 + answer = 2 + propisition = 3 + def question(self): + print(42) + self.assertIsNot(type(Why.question), Why) + self.assertNotIn(Why.question, Why._member_names) + self.assertNotIn(Why.question, Why) + + def test_wrong_inheritance_order(self): + with self.assertRaises(TypeError): + class Wrong(Enum, str): + NotHere = 'error before this point' + + def test_wrong_enum_in_call(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_wrong_enum_in_mixed_call(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_mixed_enum_in_call_1(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertIs(Monochrome(Gender.female), Monochrome.white) + + def test_mixed_enum_in_call_2(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertIs(Monochrome(Gender.male), Monochrome.black) + + def test_flufl_enum(self): + class Fluflnum(Enum): + def __int__(self): + return int(self.value) + class MailManOptions(Fluflnum): + option1 = 1 + option2 = 2 + option3 = 3 + self.assertEqual(int(MailManOptions.option1), 1) + + def test_no_such_enum_member(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + with self.assertRaises(ValueError): + Color(4) + with self.assertRaises(KeyError): + Color['chartreuse'] + + def test_new_repr(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Color.blue), + "don't you just love shades of blue?", + ) + + def test_inherited_repr(self): + class MyEnum(Enum): + def __repr__(self): + return "My name is %s." % self.name + class MyIntEnum(int, MyEnum): + this = 1 + that = 2 + theother = 3 + self.assertEqual(repr(MyIntEnum.that), "My name is that.") + + def test_multiple_mixin_mro(self): + class auto_enum(type(Enum)): + def __new__(metacls, cls, bases, classdict): + temp = type(classdict)() + names = set(classdict._member_names) + i = 0 + for k in classdict._member_names: + v = classdict[k] + if v is Ellipsis: + v = i + else: + i = v + i += 1 + temp[k] = v + for k, v in classdict.items(): + if k not in names: + temp[k] = v + return super(auto_enum, metacls).__new__( + metacls, cls, bases, temp) + + class AutoNumberedEnum(Enum, metaclass=auto_enum): + pass + + class AutoIntEnum(IntEnum, metaclass=auto_enum): + pass + + class TestAutoNumber(AutoNumberedEnum): + a = ... + b = 3 + c = ... + + class TestAutoInt(AutoIntEnum): + a = ... + b = 3 + c = ... + + def test_subclasses_with_getnewargs(self): + class NamedInt(int): + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs__(self): + return self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(loads(dumps(NI5)), 5) + self.assertEqual(NEI.y.value, 2) + self.assertIs(loads(dumps(NEI.y)), NEI.y) + + def test_subclasses_without_getnewargs(self): + class NamedInt(int): + def __new__(cls, *args): + _args = args + name, *args = args + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + with self.assertRaises(TypeError): + dumps(NEI.x) + with self.assertRaises(PicklingError): + dumps(NEI) + + def test_tuple_subclass(self): + class SomeTuple(tuple, Enum): + first = (1, 'for the money') + second = (2, 'for the show') + third = (3, 'for the music') + self.assertIs(type(SomeTuple.first), SomeTuple) + self.assertIsInstance(SomeTuple.second, tuple) + self.assertEqual(SomeTuple.third, (3, 'for the music')) + globals()['SomeTuple'] = SomeTuple + self.assertIs(loads(dumps(SomeTuple.first)), SomeTuple.first) + + def test_duplicate_values_give_unique_enum_items(self): + class AutoNumber(Enum): + first = () + second = () + third = () + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value = value + return obj + def __int__(self): + return int(self._value) + self.assertEqual( + list(AutoNumber), + [AutoNumber.first, AutoNumber.second, AutoNumber.third], + ) + self.assertEqual(int(AutoNumber.second), 2) + self.assertIs(AutoNumber(1), AutoNumber.first) + + def test_inherited_new_from_enhanced_enum(self): + class AutoNumber(Enum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value = value + return obj + def __int__(self): + return int(self._value) + class Color(AutoNumber): + red = () + green = () + blue = () + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_inherited_new_from_mixed_enum(self): + class AutoNumber(IntEnum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = int.__new__(cls, value) + obj._value = value + return obj + class Color(AutoNumber): + red = () + green = () + blue = () + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_ordered_mixin(self): + class OrderedEnum(Enum): + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value >= other._value + return NotImplemented + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value > other._value + return NotImplemented + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value <= other._value + return NotImplemented + def __lt__(self, other): + if self.__class__ is other.__class__: + return self._value < other._value + return NotImplemented + class Grade(OrderedEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 1 + self.assertGreater(Grade.A, Grade.B) + self.assertLessEqual(Grade.F, Grade.C) + self.assertLess(Grade.D, Grade.A) + self.assertGreaterEqual(Grade.B, Grade.B) + + +if __name__ == '__main__': + unittest.main()