diff --git a/Doc/library/random.rst b/Doc/library/random.rst --- a/Doc/library/random.rst +++ b/Doc/library/random.rst @@ -121,6 +121,14 @@ Return a random element from the non-empty sequence *seq*. If *seq* is empty, raises :exc:`IndexError`. +.. function:: choices(seq, k) + + Return *k* random elements from the non-empty sequence *seq*. If *seq* is + empty, raises :exc:`IndexError`. + + Returns a new list containing elements from the *seq* while leaving the + original *seq* unchanged. An element can appear up to *k* times. *k* can + be larger than *len(seq)*. .. function:: shuffle(x[, random]) diff --git a/Lib/random.py b/Lib/random.py --- a/Lib/random.py +++ b/Lib/random.py @@ -48,7 +48,7 @@ "randrange","shuffle","normalvariate","lognormvariate", "expovariate","vonmisesvariate","gammavariate","triangular", "gauss","betavariate","paretovariate","weibullvariate", - "getstate","setstate", "getrandbits", + "getstate","setstate", "getrandbits", "choices", "SystemRandom"] NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0) @@ -254,6 +254,17 @@ raise IndexError('Cannot choose from an empty sequence') return seq[i] + def choices(self, seq, k): + """Choose k random elements from a non-empty sequence.""" + seqlen = len(seq) + if seqlen == 0: + raise IndexError('Cannot choose from an empty sequence') + result = [None] * k + for i in range(k): + idx = self._randbelow(seqlen) + result[i] = seq[idx] + return result + def shuffle(self, x, random=None, int=int): """Shuffle list x in place, and return None. @@ -710,6 +721,7 @@ triangular = _inst.triangular randint = _inst.randint choice = _inst.choice +choices = _inst.choices randrange = _inst.randrange sample = _inst.sample shuffle = _inst.shuffle diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -97,6 +97,14 @@ self.assertEqual(choice([50]), 50) self.assertIn(choice([25, 75]), [25, 75]) + def test_choices(self): + choices = self.gen.choices + with self.assertRaises(IndexError): + choices([], 3) + self.assertEqual(choices([50], 3), [50, 50, 50]) + self.assertIn(choices([25, 75], 1)[0], [25, 75]) + self.assertTrue(set(choices("abcd", 2)).issubset(set("abcd"))) + def test_sample(self): # For the entire allowable range of 0 <= k <= N, validate that # the sample is of the correct length and contains only unique items