diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -861,6 +861,123 @@ class ChainMap(MutableMapping): self.maps[0].clear() +######################################################################## +### transformdict +######################################################################## + +_sentinel = object() + +class transformdict(MutableMapping): + '''Dictionary that calls a transformation function when looking + up keys, but preserves the original keys. + + >>> d = transformdict(str.lower) + >>> d['Foo'] = 5 + >>> d['foo'] == d['FOO'] == d['Foo'] == 5 + True + >>> set(d.keys()) + {'Foo'} + ''' + + __slots__ = ('_transform', '_original', '_data') + + def __init__(self, transform, init_dict=None, **kwargs): + '''Create a new transformdict with the given *transform* function. + *init_dict* and *kwargs* are optional initializers, as in the + dict constructor. + ''' + if not callable(transform): + raise TypeError("expected a callable, got %r" % transform.__class__) + self._transform = transform + # transformed => original + self._original = {} + self._data = {} + if init_dict: + self.update(init_dict) + if kwargs: + self.update(kwargs) + + # Minimum set of methods required for MutableMapping + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._original.values()) + + def __getitem__(self, key): + return self._data[self._transform(key)] + + def __setitem__(self, key, value): + transformed = self._transform(key) + self._data[transformed] = value + self._original.setdefault(transformed, key) + + def __delitem__(self, key): + transformed = self._transform(key) + del self._data[transformed] + del self._original[transformed] + + # Methods overriden to mitigate the performance overhead. + + def clear(self): + 'D.clear() -> None. Remove all items from D.' + self._data.clear() + self._original.clear() + + def __contains__(self, key): + return self._transform(key) in self._data + + def get(self, key, default=None): + 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.' + return self._data.get(self._transform(key), default) + + def pop(self, key, default=_sentinel): + '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. + ''' + transformed = self._transform(key) + if default is _sentinel: + del self._original[transformed] + return self._data.pop(transformed) + else: + self._original.pop(transformed, None) + return self._data.pop(transformed, default) + + def popitem(self): + '''D.popitem() -> (k, v), remove and return some (key, value) pair + as a 2-tuple; but raise KeyError if D is empty. + ''' + transformed, value = self._data.popitem() + return self._original.pop(transformed), value + + # Other methods + + def copy(self): + 'D.copy() -> a shallow copy of D' + other = self.__class__(self._transform) + other._original = self._original.copy() + other._data = self._data.copy() + return other + + __copy__ = copy + + def __getstate__(self): + return (self._transform, self._data, self._original) + + def __setstate__(self, state): + self._transform, self._data, self._original = state + + def __repr__(self): + try: + equiv = dict(self) + except TypeError: + # Some keys are unhashable, fall back on .items() + equiv = list(self.items()) + return '%s(%r, %s)' % (self.__class__.__name__, + self._transform, repr(equiv)) + + ################################################################################ ### UserDict ################################################################################ diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -8,11 +8,13 @@ from collections import namedtuple, Coun from test import mapping_tests import pickle, copy from random import randrange, shuffle +from functools import partial import keyword import re import sys from collections import UserDict from collections import ChainMap +from collections import transformdict from collections.abc import Hashable, Iterable, Iterator from collections.abc import Sized, Container, Callable from collections.abc import Set, MutableSet @@ -1353,6 +1355,289 @@ class SubclassMappingTests(mapping_tests self.assertRaises(KeyError, d.popitem) +def str_lower(s): + return s.lower() + +################################################################################ +### transformdict +################################################################################ + +class TransformDictTestBase(unittest.TestCase): + + def check_underlying_dict(self, d, expected): + """ + Check for implementation details. + """ + self.assertEqual(d._data, expected) + self.assertEqual(set(d._original), set(expected)) + self.assertEqual([d._transform(v) for v in d._original.values()], + list(d._original.keys())) + + +class TestTransformDict(TransformDictTestBase): + + def test_init(self): + with self.assertRaises(TypeError): + transformdict() + with self.assertRaises(TypeError): + # Too many positional args + transformdict(str.lower, {}, {}) + with self.assertRaises(TypeError): + # Not a callable + transformdict(object()) + d = transformdict(str.lower) + self.check_underlying_dict(d, {}) + pairs = [('Bar', 1), ('Foo', 2)] + d = transformdict(str.lower, pairs) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, dict(pairs)) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, **dict(pairs)) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, {'Bar': 1}, Foo=2) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + + def test_various_transforms(self): + d = transformdict(lambda s: s.encode('utf-8')) + d['Foo'] = 5 + self.assertEqual(d['Foo'], 5) + self.check_underlying_dict(d, {b'Foo': 5}) + with self.assertRaises(AttributeError): + # 'bytes' object has no attribute 'encode' + d[b'Foo'] + # Another example + d = transformdict(str.swapcase) + d['Foo'] = 5 + self.assertEqual(d['Foo'], 5) + self.check_underlying_dict(d, {'fOO': 5}) + with self.assertRaises(KeyError): + d['fOO'] + + # NOTE: we mostly test the operations which are not inherited from + # MutableMapping. + + def test_setitem_getitem(self): + d = transformdict(str.lower) + with self.assertRaises(KeyError): + d['foo'] + d['Foo'] = 5 + self.assertEqual(d['foo'], 5) + self.assertEqual(d['Foo'], 5) + self.assertEqual(d['FOo'], 5) + with self.assertRaises(KeyError): + d['bar'] + self.check_underlying_dict(d, {'foo': 5}) + d['BAR'] = 6 + self.assertEqual(d['Bar'], 6) + self.check_underlying_dict(d, {'foo': 5, 'bar': 6}) + # Overwriting + d['foO'] = 7 + self.assertEqual(d['foo'], 7) + self.assertEqual(d['Foo'], 7) + self.assertEqual(d['FOo'], 7) + self.check_underlying_dict(d, {'foo': 7, 'bar': 6}) + + def test_delitem(self): + d = transformdict(str.lower, Foo=5) + d['baR'] = 3 + del d['fOO'] + with self.assertRaises(KeyError): + del d['Foo'] + with self.assertRaises(KeyError): + del d['foo'] + self.check_underlying_dict(d, {'bar': 3}) + + def test_get(self): + d = transformdict(str.lower) + default = object() + self.assertIs(d.get('foo'), None) + self.assertIs(d.get('foo', default), default) + d['Foo'] = 5 + self.assertEqual(d.get('foo'), 5) + self.assertEqual(d.get('FOO'), 5) + self.assertIs(d.get('bar'), None) + self.check_underlying_dict(d, {'foo': 5}) + + def test_pop(self): + d = transformdict(str.lower) + default = object() + with self.assertRaises(KeyError): + d.pop('foo') + self.assertIs(d.pop('foo', default), default) + d['Foo'] = 5 + self.assertIn('foo', d) + self.assertEqual(d.pop('foo'), 5) + self.assertNotIn('foo', d) + self.check_underlying_dict(d, {}) + d['Foo'] = 5 + self.assertIn('Foo', d) + self.assertEqual(d.pop('FOO'), 5) + self.assertNotIn('foo', d) + self.check_underlying_dict(d, {}) + with self.assertRaises(KeyError): + d.pop('foo') + + def test_clear(self): + d = transformdict(str.lower) + d.clear() + self.check_underlying_dict(d, {}) + d['Foo'] = 5 + d['baR'] = 3 + self.check_underlying_dict(d, {'foo': 5, 'bar': 3}) + d.clear() + self.check_underlying_dict(d, {}) + + def test_contains(self): + d = transformdict(str.lower) + self.assertIs(False, 'foo' in d) + d['Foo'] = 5 + self.assertIs(True, 'Foo' in d) + self.assertIs(True, 'foo' in d) + self.assertIs(True, 'FOO' in d) + self.assertIs(False, 'bar' in d) + + def test_len(self): + d = transformdict(str.lower) + self.assertEqual(len(d), 0) + d['Foo'] = 5 + self.assertEqual(len(d), 1) + d['BAR'] = 6 + self.assertEqual(len(d), 2) + d['foo'] = 7 + self.assertEqual(len(d), 2) + d['baR'] = 3 + self.assertEqual(len(d), 2) + del d['Bar'] + self.assertEqual(len(d), 1) + + def test_iter(self): + d = transformdict(str.lower) + it = iter(d) + with self.assertRaises(StopIteration): + next(it) + d['Foo'] = 5 + d['BAR'] = 6 + self.assertEqual(set(x for x in d), {'Foo', 'BAR'}) + + def test_first_key_retained(self): + d = transformdict(str.lower, {'Foo': 5, 'BAR': 6}) + self.assertEqual(set(d), {'Foo', 'BAR'}) + d['foo'] = 7 + d['baR'] = 8 + d['quux'] = 9 + self.assertEqual(set(d), {'Foo', 'BAR', 'quux'}) + del d['foo'] + d['FOO'] = 9 + del d['bar'] + d.setdefault('Bar', 15) + d.setdefault('BAR', 15) + self.assertEqual(set(d), {'FOO', 'Bar', 'quux'}) + + def test_repr(self): + d = transformdict(str.lower) + self.assertEqual(repr(d), + "transformdict(, {})") + d['Foo'] = 5 + self.assertEqual(repr(d), + "transformdict(, {'Foo': 5})") + d['Bar'] = 6 + if next(iter(d)) == 'Foo': + self.assertEqual(repr(d), + "transformdict(, " + "{'Foo': 5, 'Bar': 6})") + else: + self.assertEqual(repr(d), + "transformdict(, " + "{'Bar': 6, 'Foo': 5})") + + def test_repr(self): + d = transformdict(str.lower) + self.assertEqual(repr(d), + "transformdict(, {})") + d['Foo'] = 5 + self.assertEqual(repr(d), + "transformdict(, {'Foo': 5})") + + def test_repr_non_hashable_keys(self): + d = transformdict(id) + self.assertEqual(repr(d), + "transformdict(, {})") + d[[1]] = 2 + self.assertEqual(repr(d), + "transformdict(, [([1], 2)])") + + +class TransformDictMappingTests(TransformDictTestBase, + mapping_tests.BasicTestMappingProtocol): + + transformdict = transformdict + type2test = partial(transformdict, str.lower) + + def check_shallow_copy(self, copy_func): + d = self.transformdict(str_lower, {'Foo': []}) + e = copy_func(d) + self.assertIs(e.__class__, self.transformdict) + self.assertIs(e._transform, str_lower) + self.check_underlying_dict(e, {'foo': []}) + e['Bar'] = 6 + self.assertEqual(e['bar'], 6) + with self.assertRaises(KeyError): + d['bar'] + e['foo'].append(5) + self.assertEqual(d['foo'], [5]) + self.assertEqual(set(e), {'Foo', 'Bar'}) + + def check_deep_copy(self, copy_func): + d = self.transformdict(str_lower, {'Foo': []}) + e = copy_func(d) + self.assertIs(e.__class__, self.transformdict) + self.assertIs(e._transform, str_lower) + self.check_underlying_dict(e, {'foo': []}) + e['Bar'] = 6 + self.assertEqual(e['bar'], 6) + with self.assertRaises(KeyError): + d['bar'] + e['foo'].append(5) + self.assertEqual(d['foo'], []) + self.check_underlying_dict(e, {'foo': [5], 'bar': 6}) + self.assertEqual(set(e), {'Foo', 'Bar'}) + + def test_copy(self): + self.check_shallow_copy(lambda d: d.copy()) + + def test_copy_copy(self): + self.check_shallow_copy(copy.copy) + + def test_cast_as_dict(self): + d = self.transformdict(str.lower, {'Foo': 5}) + e = dict(d) + self.assertEqual(e, {'Foo': 5}) + + def test_copy_deepcopy(self): + self.check_deep_copy(copy.deepcopy) + + def test_pickling(self): + def pickle_unpickle(obj, proto): + data = pickle.dumps(obj, proto) + return pickle.loads(data) + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_protocol=proto): + self.check_deep_copy(partial(pickle_unpickle, proto=proto)) + + +class MyTransformDict(transformdict): + pass + +class TransformDictSubclassMappingTests(TransformDictMappingTests): + + transformdict = MyTransformDict + type2test = partial(MyTransformDict, str.lower) + + ################################################################################ ### Run tests ################################################################################ @@ -1363,7 +1648,9 @@ def test_main(verbose=None): NamedTupleDocs = doctest.DocTestSuite(module=collections) test_classes = [TestNamedTuple, NamedTupleDocs, TestOneTrickPonyABCs, TestCollectionABCs, TestCounter, TestChainMap, - TestOrderedDict, GeneralMappingTests, SubclassMappingTests] + TestOrderedDict, GeneralMappingTests, SubclassMappingTests, + TestTransformDict, TransformDictMappingTests, + TransformDictSubclassMappingTests] support.run_unittest(*test_classes) support.run_doctest(collections, verbose)