Index: Doc/library/itertools.rst =================================================================== --- Doc/library/itertools.rst (revision 68387) +++ Doc/library/itertools.rst (working copy) @@ -92,6 +92,7 @@ .. function:: combinations(iterable, r) Return *r* length subsequences of elements from the input *iterable*. + If *r* is larger than the *iterable*, the result will be empty. Combinations are emitted in lexicographic sort order. So, if the input *iterable* is sorted, the combination tuples will be produced @@ -108,6 +109,8 @@ # combinations(range(4), 3) --> 012 013 023 123 pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -381,7 +384,8 @@ If *r* is not specified or is ``None``, then *r* defaults to the length of the *iterable* and all possible full-length permutations - are generated. + are generated. If *r* is larger than the *iterable*, the result + will be empty. Permutations are emitted in lexicographic sort order. So, if the input *iterable* is sorted, the permutation tuples will be produced @@ -399,6 +403,8 @@ pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -677,6 +683,8 @@ "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" pool = tuple(iterable) n = len(pool) + if r > n: + return indices = [0] * r yield tuple(pool[i] for i in indices) while 1: Index: Lib/test/test_itertools.py =================================================================== --- Lib/test/test_itertools.py (revision 68387) +++ Lib/test/test_itertools.py (working copy) @@ -75,7 +75,7 @@ self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative - self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations('abc', 32)), []) # r > n self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) @@ -83,6 +83,8 @@ 'Pure python version shown in the docs' pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -106,9 +108,9 @@ for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+3): result = list(combinations(values, r)) - self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for c in result: @@ -119,7 +121,7 @@ self.assertEqual(list(c), [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version - self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) @@ -130,7 +132,7 @@ self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative - self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations('abc', 32)), []) # r > n self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -140,6 +142,8 @@ pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -168,9 +172,9 @@ for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+3): result = list(permutations(values, r)) - self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for p in result: @@ -178,7 +182,7 @@ self.assertEqual(len(set(p)), r) # no duplicate elements self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version - self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version if r == n: self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r @@ -1266,6 +1270,8 @@ ... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" ... pool = tuple(iterable) ... n = len(pool) +... if r > n: +... return ... indices = [0] * r ... yield tuple(pool[i] for i in indices) ... while 1: @@ -1363,6 +1369,9 @@ >>> list(combinations_with_replacement('abc', 2)) [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] +>>> list(combinations_with_replacement('abc', 10)) +[] + >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] Index: Modules/itertoolsmodule.c =================================================================== --- Modules/itertoolsmodule.c (revision 68387) +++ Modules/itertoolsmodule.c (working copy) @@ -2059,10 +2059,6 @@ PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); if (indices == NULL) { @@ -2082,7 +2078,7 @@ co->indices = indices; co->result = NULL; co->r = r; - co->stopped = 0; + co->stopped = r > n ? 1 : 0; return (PyObject *)co; @@ -2318,10 +2314,6 @@ PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); @@ -2345,7 +2337,7 @@ po->cycles = cycles; po->result = NULL; po->r = r; - po->stopped = 0; + po->stopped = r > n ? 1 : 0; return (PyObject *)po;