diff --git a/Lib/copy.py b/Lib/copy.py index f86040a..7fdd459 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -127,6 +127,55 @@ d[bytearray] = bytearray.copy if PyStringMap is not None: d[PyStringMap] = PyStringMap.copy +def _copy_itertools_chain(x): + rv = x.__reduce_ex__(4) + if len(rv) == 3: + if len(rv[2]) in (1, 2): + from itertools import tee + from operator import itemgetter + source = rv[2][0] + source1, source2 = tee(map(tee, source)) + source1 = map(itemgetter(0), source1) + source2 = map(itemgetter(1), source2) + if len(rv[2]) == 1: + x.__setstate__((source1,)) + rv = rv[:2] + ((source2,),) + else: + active = rv[2][1] + active1, active2 = tee(active) + x.__setstate__((source1, active1)) + rv = rv[:2] + ((source2, active2),) + del active, active1, active2 + del source, source1, source2 + return _reconstruct(x, None, *rv) + +def _copy_itertools_chain(x): + rv = x.__reduce_ex__(4) + if len(rv) == 3: + assert len(rv[2]) in (0, 1, 2) + if len(rv[2]) in (1, 2): + from itertools import tee + from operator import itemgetter + source = rv[2][0] + source1, source2 = tee(map(tee, source)) + source1 = map(itemgetter(0), source1) + source2 = map(itemgetter(1), source2) + if len(rv[2]) == 1: + x.__setstate__((source1,)) + rv = rv[:2] + ((source2,),) + else: + active = rv[2][1] + active1 = copy(active) + active2 = copy(active) + x.__setstate__((source1, active1)) + rv = rv[:2] + ((source2, active2),) + del active, active2 + del source, source1, source2 + return _reconstruct(x, None, *rv) +import itertools +d[itertools.chain] = _copy_itertools_chain +del itertools + del d, t def deepcopy(x, memo=None, _nil=[]): diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index c431f0d..f798771 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -196,6 +196,23 @@ class TestBasicOps(unittest.TestCase): it.__setstate__((iter(['abc', 'def']), iter(['ghi']))) self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f']) + def test_chain_copy(self): + a = chain(iter('ab'), iter('cd')) + b = copy.copy(a) + self.assertEqual(list(b), list('abcd')) + + self.assertEqual(next(a), 'a') + b = copy.copy(a) + self.assertEqual(list(b), list('bcd')) + + self.assertEqual(next(a), 'b') + b = copy.copy(a) + self.assertEqual(list(b), list('cd')) + + self.assertEqual(list(a), list('cd')) + b = copy.copy(a) + self.assertEqual(list(b), []) + def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments