diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -861,6 +861,89 @@ self.maps[0].clear() +_sentinel = object() + + +class transformdict(MutableMapping): + + __slots__ = ('_transform', '_data') + + def __init__(self, transform, init_dict=None, **kwargs): + if not callable(transform): + raise TypeError("expected a callable, got %r" % transform.__class__) + self._transform = transform + self._data = {} + if init_dict: + self.update(init_dict) + if kwargs: + self.update(kwargs) + + # Minimum set of methods required for MutableMapping + + def __len__(self): + return len(self._data) + + def __iter__(self): + return (v[0] for v in self._data.values()) + + def __getitem__(self, key): + return self._data[self._transform(key)][1] + + def __setitem__(self, key, value): + self._data[self._transform(key)] = key, value + + def __delitem__(self, key): + del self._data[self._transform(key)] + + # Methods overriden to mitigate the performance overhead. + + def clear(self): + self._data.clear() + + def __contains__(self, key): + return self._transform(key) in self._data + + def get(self, key, default=_sentinel): + tup = self._data.get(self._transform(key)) + if tup is not None: + return tup[1] + elif default is not _sentinel: + return default + else: + return None + + def pop(self, key, default=_sentinel): + if default is not _sentinel: + tup = self._data.pop(self._transform(key), default) + else: + tup = self._data.pop(self._transform(key)) + if tup is not default: + return tup[1] + else: + return default + + # Other methods + + def copy(self): + other = self.__class__(self._transform) + other._data.update(self._data) + return other + + __copy__ = copy + + def __reduce__(self): + 'Return state information for pickle and deepcopy.' + if hasattr(self, '__dict__'): + inst_dict = self.__dict__.copy() + else: + inst_dict = None + return self.__class__, (self._transform,), inst_dict, None, iter(self.items()) + + def __repr__(self): + return '%s(%r, %s)' % (self.__class__.__name__, + self._transform, repr(dict(self))) + + ################################################################################ ### UserDict ################################################################################ diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -8,11 +8,13 @@ from test import mapping_tests import pickle, copy from random import randrange, shuffle +from functools import partial import keyword import re import sys from collections import UserDict from collections import ChainMap +from collections import transformdict from collections.abc import Hashable, Iterable, Iterator from collections.abc import Sized, Container, Callable from collections.abc import Set, MutableSet @@ -1353,6 +1355,253 @@ self.assertRaises(KeyError, d.popitem) +def str_lower(s): + return s.lower() + +################################################################################ +### transformdict +################################################################################ + +class TestTransformDict(unittest.TestCase): + + def check_underlying_dict(self, d, expected): + """ + Check for implementation details. + """ + self.assertEqual(set(d._data), set(expected)) + self.assertEqual({k: v[1] for k, v in d._data.items()}, expected) + #dict_iter = dict.__iter__ + #dict_getitem = dict.__getitem__ + #self.assertEqual(set(dict_iter(d)), set(expected)) + #self.assertEqual({k: dict_getitem(d, k)[1] for k in dict_iter(d)}, expected) + #self.assertEqual(dict.__len__(d), len(expected)) + + def test_init(self): + with self.assertRaises(TypeError): + transformdict() + with self.assertRaises(TypeError): + # Too many positional args + transformdict(str.lower, {}, {}) + d = transformdict(str.lower) + self.check_underlying_dict(d, {}) + pairs = [('Bar', 1), ('Foo', 2)] + d = transformdict(str.lower, pairs) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, dict(pairs)) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, **dict(pairs)) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + d = transformdict(str.lower, {'Bar': 1}, Foo=2) + self.assertEqual(sorted(d.items()), pairs) + self.check_underlying_dict(d, {'bar': 1, 'foo': 2}) + + def test_various_transforms(self): + d = transformdict(lambda s: s.encode('utf-8')) + d['Foo'] = 5 + self.assertEqual(d['Foo'], 5) + self.check_underlying_dict(d, {b'Foo': 5}) + with self.assertRaises(AttributeError): + # 'bytes' object has no attribute 'encode' + d[b'Foo'] + # Another example + d = transformdict(str.swapcase) + d['Foo'] = 5 + self.assertEqual(d['Foo'], 5) + self.check_underlying_dict(d, {'fOO': 5}) + with self.assertRaises(KeyError): + d['fOO'] + + # NOTE: we only test the operations which are not inherited from + # MutableMapping. + + def test_setitem_getitem(self): + d = transformdict(str.lower) + with self.assertRaises(KeyError): + d['foo'] + d['Foo'] = 5 + self.assertEqual(d['foo'], 5) + self.assertEqual(d['Foo'], 5) + self.assertEqual(d['FOo'], 5) + with self.assertRaises(KeyError): + d['bar'] + self.check_underlying_dict(d, {'foo': 5}) + d['BAR'] = 6 + self.assertEqual(d['Bar'], 6) + self.check_underlying_dict(d, {'foo': 5, 'bar': 6}) + # Overwriting + d['foO'] = 7 + self.assertEqual(d['foo'], 7) + self.assertEqual(d['Foo'], 7) + self.assertEqual(d['FOo'], 7) + self.check_underlying_dict(d, {'foo': 7, 'bar': 6}) + + def test_delitem(self): + d = transformdict(str.lower, Foo=5) + d['baR'] = 3 + del d['fOO'] + with self.assertRaises(KeyError): + del d['Foo'] + with self.assertRaises(KeyError): + del d['foo'] + self.check_underlying_dict(d, {'bar': 3}) + + def test_get(self): + d = transformdict(str.lower) + default = object() + self.assertIs(d.get('foo'), None) + self.assertIs(d.get('foo', default), default) + d['Foo'] = 5 + self.assertEqual(d.get('foo'), 5) + self.assertEqual(d.get('FOO'), 5) + self.assertIs(d.get('bar'), None) + self.check_underlying_dict(d, {'foo': 5}) + + def test_pop(self): + d = transformdict(str.lower) + default = object() + with self.assertRaises(KeyError): + d.pop('foo') + self.assertIs(d.pop('foo', default), default) + d['Foo'] = 5 + self.assertIn('foo', d) + self.assertEqual(d.pop('foo'), 5) + self.assertNotIn('foo', d) + self.check_underlying_dict(d, {}) + d['Foo'] = 5 + self.assertIn('Foo', d) + self.assertEqual(d.pop('FOO'), 5) + self.assertNotIn('foo', d) + self.check_underlying_dict(d, {}) + with self.assertRaises(KeyError): + d.pop('foo') + + def test_clear(self): + d = transformdict(str.lower) + d.clear() + self.check_underlying_dict(d, {}) + d['Foo'] = 5 + d['baR'] = 3 + self.check_underlying_dict(d, {'foo': 5, 'bar': 3}) + d.clear() + self.check_underlying_dict(d, {}) + + def test_contains(self): + d = transformdict(str.lower) + self.assertIs(False, 'foo' in d) + d['Foo'] = 5 + self.assertIs(True, 'Foo' in d) + self.assertIs(True, 'foo' in d) + self.assertIs(True, 'FOO' in d) + self.assertIs(False, 'bar' in d) + + def test_len(self): + d = transformdict(str.lower) + self.assertEqual(len(d), 0) + d['Foo'] = 5 + self.assertEqual(len(d), 1) + d['BAR'] = 6 + self.assertEqual(len(d), 2) + d['foo'] = 7 + self.assertEqual(len(d), 2) + d['baR'] = 3 + self.assertEqual(len(d), 2) + del d['Bar'] + self.assertEqual(len(d), 1) + + def test_iter(self): + d = transformdict(str.lower) + it = iter(d) + with self.assertRaises(StopIteration): + next(it) + d['Foo'] = 5 + d['BAR'] = 6 + yielded = [] + for x in d: + yielded.append(x) + self.assertEqual(set(yielded), {'Foo', 'BAR'}) + + def test_repr(self): + d = transformdict(str.lower) + self.assertEqual(repr(d), + "transformdict(, {})") + d['Foo'] = 5 + self.assertEqual(repr(d), + "transformdict(, {'Foo': 5})") + d['Bar'] = 6 + if next(iter(d)) == 'Foo': + self.assertEqual(repr(d), + "transformdict(, " + "{'Foo': 5, 'Bar': 6})") + else: + self.assertEqual(repr(d), + "transformdict(, " + "{'Bar': 6, 'Foo': 5})") + + def check_shallow_copy(self, copy_func): + d = transformdict(str_lower, {'Foo': []}) + e = copy_func(d) + self.assertIsInstance(e, transformdict) + self.assertIs(e._transform, str_lower) + self.check_underlying_dict(e, {'foo': []}) + e['Bar'] = 6 + self.assertEqual(e['bar'], 6) + with self.assertRaises(KeyError): + d['bar'] + e['foo'].append(5) + self.assertEqual(d['foo'], [5]) + self.assertEqual(set(e), {'Foo', 'Bar'}) + + def check_deep_copy(self, copy_func): + d = transformdict(str_lower, {'Foo': []}) + e = copy_func(d) + self.assertIsInstance(e, transformdict) + self.assertIs(e._transform, str_lower) + self.check_underlying_dict(e, {'foo': []}) + e['Bar'] = 6 + self.assertEqual(e['bar'], 6) + with self.assertRaises(KeyError): + d['bar'] + e['foo'].append(5) + self.assertEqual(d['foo'], []) + self.check_underlying_dict(e, {'foo': [5], 'bar': 6}) + self.assertEqual(set(e), {'Foo', 'Bar'}) + + def test_copy(self): + self.check_shallow_copy(lambda d: d.copy()) + + def test_copy_copy(self): + self.check_shallow_copy(copy.copy) + + def test_cast_as_dict(self): + d = transformdict(str.lower, {'Foo': 5}) + e = dict(d) + self.assertEqual(e, {'Foo': 5}) + + def test_copy_deepcopy(self): + self.check_deep_copy(copy.deepcopy) + + def test_pickling(self): + def pickle_unpickle(obj, proto): + data = pickle.dumps(obj, proto) + return pickle.loads(data) + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_protocol=proto): + self.check_deep_copy(partial(pickle_unpickle, proto=proto)) + + +class TransformDictMappingTests(mapping_tests.BasicTestMappingProtocol): + type2test = partial(transformdict, str.lower) + +class MyTransformDict(transformdict): + pass + +class TransformDictSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + type2test = partial(MyTransformDict, str.lower) + + ################################################################################ ### Run tests ################################################################################ @@ -1363,7 +1612,9 @@ NamedTupleDocs = doctest.DocTestSuite(module=collections) test_classes = [TestNamedTuple, NamedTupleDocs, TestOneTrickPonyABCs, TestCollectionABCs, TestCounter, TestChainMap, - TestOrderedDict, GeneralMappingTests, SubclassMappingTests] + TestOrderedDict, GeneralMappingTests, SubclassMappingTests, + TestTransformDict, TransformDictMappingTests, + TransformDictSubclassMappingTests] support.run_unittest(*test_classes) support.run_doctest(collections, verbose)