diff -r 8fb3a6f9b0a4 Lib/random.py --- a/Lib/random.py Mon Aug 26 14:05:19 2013 +0200 +++ b/Lib/random.py Mon Aug 26 19:07:11 2013 -0400 @@ -43,6 +43,8 @@ from os import urandom as _urandom from collections.abc import Set as _Set, Sequence as _Sequence from hashlib import sha512 as _sha512 +from itertools import accumulate as _accumulate +from bisect import bisect as _bisect __all__ = ["Random","seed","random","uniform","randint","choice","sample", "randrange","shuffle","normalvariate","lognormvariate", @@ -246,14 +248,43 @@ ## -------------------- sequence methods ------------------- - def choice(self, seq): - """Choose a random element from a non-empty sequence.""" - try: - i = self._randbelow(len(seq)) - except ValueError: - raise IndexError('Cannot choose from an empty sequence') + def choice(self, seq, weights=None): + """Choose a random element from a non-empty sequence. + + Optional argument weights is a list of non-negative numbers used to + alters the probability that element at index i in the sequence is + selected. + """ + + if not weights: + try: + i = self._randbelow(len(seq)) + except ValueError: + raise IndexError('Cannot choose from an empty sequence') + return seq[i] + + # weights argument supplied + def _add(cumulative, partial): + try: + partial = float(partial) + except ValueError: + raise TypeError("All weights must be numeric") + if partial < 0: + raise ValueError("All weights must be non-negative") + return cumulative + partial + + cumulative_dist = list(_accumulate(weights, _add)) + + if len(weights) != len(seq): + raise ValueError("Length of weights must equal length of sequence") + + if not any(weights): + raise ValueError("At least one weight must be greater than zero") + + x = self.random() * cumulative_dist[-1] + i = _bisect(cumulative_dist, x) return seq[i] - + def shuffle(self, x, random=None, int=int): """Shuffle list x in place, and return None. diff -r 8fb3a6f9b0a4 Lib/test/test_random.py --- a/Lib/test/test_random.py Mon Aug 26 14:05:19 2013 +0200 +++ b/Lib/test/test_random.py Mon Aug 26 19:07:11 2013 -0400 @@ -96,6 +96,11 @@ choice([]) self.assertEqual(choice([50]), 50) self.assertIn(choice([25, 75]), [25, 75]) + self.assertIn(choice([25, 75], [.25, .75]), [25, 75]) + self.assertRaises(ValueError, choice, [25, 75], [.25]) + self.assertRaises(ValueError, choice, [25, 75], [-1, -2]) + self.assertRaises(ValueError, choice, [25, 75], [0, 0]) + self.assertRaises(TypeError, choice, [25, 75], ['str1', 'str2']) def test_sample(self): # For the entire allowable range of 0 <= k <= N, validate that