diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -293,7 +293,7 @@ Counter({'b': 4}) .. versionadded:: 3.3 - Added support for unary plus and unary minus. + Added support for unary plus, unary minus, and in-place multiset operations. .. note:: diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -683,6 +683,69 @@ ''' return Counter() - self + def _keep_positive(self): + '''Internal method to strip elements with a negative or zero count''' + nonpositive = [elem for elem, count in self.items() if not count > 0] + for elem in nonpositive: + del self[elem] + return self + + def __iadd__(self, other): + '''Inplace add from another counter, keeping only positive counts. + + >>> c = Counter('abbb') + >>> c += Counter('bcc') + >>> c + Counter({'b': 4, 'c': 2, 'a': 1}) + + ''' + for elem, count in other.items(): + self[elem] += count + return self._keep_positive() + + def __isub__(self, other): + '''Inplace subtract counter, but keep only results with positive counts. + + >>> c = Counter('abbbc') + >>> c -= Counter('bccd') + >>> c + Counter({'b': 2, 'a': 1}) + + ''' + for elem, count in other.items(): + self[elem] -= count + return self._keep_positive() + + def __ior__(self, other): + '''Inplace union is the maximum of value from either counter. + + >>> c = Counter('abbb') + >>> c |= Counter('bcc') + >>> c + Counter({'b': 3, 'c': 2, 'a': 1}) + + ''' + for elem, other_count in other.items(): + count = self[elem] + if other_count > count: + self[elem] = other_count + return self._keep_positive() + + def __iand__(self, other): + '''Inplace intersection is the minimum of corresponding counts. + + >>> c = Counter('abbb') + >>> c &= Counter('bcc') + >>> c + Counter({'b': 1}) + + ''' + for elem, count in self.items(): + other_count = other[elem] + if other_count < count: + self[elem] = other_count + return self._keep_positive() + ######################################################################## ### ChainMap (helper for configparser and string.Template) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -932,6 +932,27 @@ set_result = setop(set(p.elements()), set(q.elements())) self.assertEqual(counter_result, dict.fromkeys(set_result, 1)) + def test_inplace_operations(self): + elements = 'abcd' + for i in range(1000): + # test random pairs of multisets + p = Counter(dict((elem, randrange(-2,4)) for elem in elements)) + p.update(e=1, f=-1, g=0) + q = Counter(dict((elem, randrange(-2,4)) for elem in elements)) + q.update(h=1, i=-1, j=0) + for inplace_op, regular_op in [ + (Counter.__iadd__, Counter.__add__), + (Counter.__isub__, Counter.__sub__), + (Counter.__ior__, Counter.__or__), + (Counter.__iand__, Counter.__and__), + ]: + c = p.copy() + c_id = id(c) + regular_result = regular_op(c, q) + inplace_result = inplace_op(c, q) + self.assertEqual(inplace_result, regular_result) + self.assertEqual(id(inplace_result), c_id) + def test_subtract(self): c = Counter(a=-5, b=0, c=5, d=10, e=15,g=40) c.subtract(a=1, b=2, c=-3, d=10, e=20, f=30, h=-50)