diff --git a/Doc/library/collections.abc.rst b/Doc/library/collections.abc.rst --- a/Doc/library/collections.abc.rst +++ b/Doc/library/collections.abc.rst @@ -55,7 +55,7 @@ :class:`Set` :class:`Sized`, ``__contains__``, ``__le__``, ``__lt__``, ``__eq__``, ``__ne__``, :class:`Iterable`, ``__iter__``, ``__gt__``, ``__ge__``, ``__and__``, ``__or__``, - :class:`Container` ``__len__`` ``__sub__``, ``__xor__``, and ``isdisjoint`` + :class:`Container` ``__len__`` ``__sub__``, ``__xor__``, ``copy``, and ``isdisjoint`` :class:`MutableSet` :class:`Set` ``__contains__``, Inherited :class:`Set` methods and ``__iter__``, ``clear``, ``pop``, ``remove``, ``__ior__``, diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py --- a/Lib/_collections_abc.py +++ b/Lib/_collections_abc.py @@ -236,6 +236,10 @@ ''' return cls(it) + def copy(self): + """Return a new set with a shallow copy of s.""" + return self._from_iterable(self) + def __and__(self, other): if not isinstance(other, Iterable): return NotImplemented diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -585,6 +585,22 @@ B.register(C) self.assertTrue(issubclass(C, B)) + +class MySet(Set): + + def __init__(self, itr): + self.contents = itr + + def __contains__(self, x): + return x in self.contents + + def __iter__(self): + return iter(self.contents) + + def __len__(self): + return len([x for x in self.contents]) + + class WithSet(MutableSet): def __init__(self, it=()): @@ -605,6 +621,7 @@ def discard(self, item): self.data.discard(item) + class TestCollectionABCs(ABCTestCase): # XXX For now, we only test some virtual inheritance properties. @@ -616,14 +633,20 @@ self.assertIsInstance(sample(), Set) self.assertTrue(issubclass(sample, Set)) self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__') - class MySet(Set): + class EmptySet(Set): def __contains__(self, x): return False def __len__(self): return 0 def __iter__(self): return iter([]) - self.validate_comparison(MySet()) + self.validate_comparison(EmptySet()) + + def test_copy_Set(self): + for s in map(MySet, [(), 'a', 'ab', 'abc']): + copy = s.copy() + self.assertEqual(s, copy) + self.assertFalse(s is copy) def test_hash_Set(self): class OneTwoThreeSet(Set): @@ -641,15 +664,6 @@ self.assertTrue(hash(a) == hash(b)) def test_isdisjoint_Set(self): - class MySet(Set): - def __init__(self, itr): - self.contents = itr - def __contains__(self, x): - return x in self.contents - def __iter__(self): - return iter(self.contents) - def __len__(self): - return len([x for x in self.contents]) s1 = MySet((1, 2, 3)) s2 = MySet((4, 5, 6)) s3 = MySet((1, 5, 6)) @@ -657,15 +671,6 @@ self.assertFalse(s1.isdisjoint(s3)) def test_equality_Set(self): - class MySet(Set): - def __init__(self, itr): - self.contents = itr - def __contains__(self, x): - return x in self.contents - def __iter__(self): - return iter(self.contents) - def __len__(self): - return len([x for x in self.contents]) s1 = MySet((1,)) s2 = MySet((1, 2)) s3 = MySet((3, 4)) @@ -679,15 +684,6 @@ self.assertNotEqual(s2, s3) def test_arithmetic_Set(self): - class MySet(Set): - def __init__(self, itr): - self.contents = itr - def __contains__(self, x): - return x in self.contents - def __iter__(self): - return iter(self.contents) - def __len__(self): - return len([x for x in self.contents]) s1 = MySet((1, 2, 3)) s2 = MySet((3, 4, 5)) s3 = s1 & s2