Index: Doc/library/itertools.rst =================================================================== --- Doc/library/itertools.rst (revision 68387) +++ Doc/library/itertools.rst (working copy) @@ -108,6 +108,8 @@ # combinations(range(4), 3) --> 012 013 023 123 pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -132,6 +134,9 @@ if sorted(indices) == list(indices): yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. + .. versionadded:: 2.6 .. function:: count([n]) @@ -399,6 +404,8 @@ pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -428,6 +435,9 @@ if len(set(indices)) == r: yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. + .. versionadded:: 2.6 .. function:: product(*iterables[, repeat]) @@ -674,7 +684,8 @@ return (d for d, s in izip(data, selectors) if s) def combinations_with_replacement(iterable, r): - "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" + "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" + # number items returned: (n+r-1)! / r! / (n-1)! pool = tuple(iterable) n = len(pool) indices = [0] * r Index: Lib/test/test_itertools.py =================================================================== --- Lib/test/test_itertools.py (revision 68387) +++ Lib/test/test_itertools.py (working copy) @@ -71,11 +71,11 @@ self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) def test_combinations(self): - self.assertRaises(TypeError, combinations, 'abc') # missing r argument + self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative - self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations('abc', 32)), []) # r > n self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) @@ -83,6 +83,8 @@ 'Pure python version shown in the docs' pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -106,9 +108,9 @@ for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(combinations(values, r)) - self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for c in result: @@ -119,7 +121,7 @@ self.assertEqual(list(c), [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version - self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) @@ -130,7 +132,7 @@ self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative - self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations('abc', 32)), []) # r > n self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -140,6 +142,8 @@ pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -168,9 +172,9 @@ for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(permutations(values, r)) - self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for p in result: @@ -178,7 +182,7 @@ self.assertEqual(len(set(p)), r) # no duplicate elements self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version - self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version if r == n: self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r @@ -1363,6 +1367,26 @@ >>> list(combinations_with_replacement('abc', 2)) [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] +>>> list(combinations_with_replacement('01', 3)) +[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')] + +>>> def combinations_with_replacement2(iterable, r): +... 'Alternate version that filters from product()' +... pool = tuple(iterable) +... n = len(pool) +... for indices in product(range(n), repeat=r): +... if sorted(indices) == list(indices): +... yield tuple(pool[i] for i in indices) + +>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2)) +True + +>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3)) +True + +>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6)) +True + >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] Index: Modules/itertoolsmodule.c =================================================================== --- Modules/itertoolsmodule.c (revision 68387) +++ Modules/itertoolsmodule.c (working copy) @@ -2059,10 +2059,6 @@ PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); if (indices == NULL) { @@ -2082,7 +2078,7 @@ co->indices = indices; co->result = NULL; co->r = r; - co->stopped = 0; + co->stopped = r > n ? 1 : 0; return (PyObject *)co; @@ -2318,10 +2314,6 @@ PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); @@ -2345,7 +2337,7 @@ po->cycles = cycles; po->result = NULL; po->r = r; - po->stopped = 0; + po->stopped = r > n ? 1 : 0; return (PyObject *)po;