diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index d21d8ed..675aa5c 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -183,6 +183,16 @@ class TestBasicOps(unittest.TestCase): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef')) + def test_chain_setstate(self): + self.assertRaises(TypeError, chain().__setstate__, ()) + self.assertRaises(SystemError, chain().__setstate__, ([])) + self.assertRaises(ValueError, chain().__setstate__, ([],)) + self.assertRaises(ValueError, chain().__setstate__, ([], 0)) + it = chain() + it.__setstate__((iter(['abc', 'def']),)) + for x in it: + pass + def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 62b6a0c..33abb0a 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1906,6 +1906,11 @@ chain_setstate(chainobject *lz, PyObject *state) if (! PyArg_ParseTuple(state, "O|O", &source, &active)) return NULL; + if (!PyIter_Check(source) || (active != NULL && !PyIter_Check(active))) { + PyErr_SetString(PyExc_ValueError, "Arguments must be iterators."); + return NULL; + } + Py_INCREF(source); Py_XSETREF(lz->source, source); Py_XINCREF(active);