diff -r 06cf4044a11a Doc/library/random.rst --- a/Doc/library/random.rst Sat Aug 09 09:34:25 2014 +0300 +++ b/Doc/library/random.rst Sun Aug 10 11:47:06 2014 +0300 @@ -121,6 +121,20 @@ raises :exc:`IndexError`. +.. function:: weighted_choice(data): + + Return a random key (if *data* is a mapping) or index (if *data* is a + sequence) with the probability that is proportional to corresponded value. + If *data* is empty, raises :exc:`IndexError`. + + +.. function:: weighted_choice_generator(data): + + Return an iterator which generates random keys (if *data* is a mapping) or + indices (if *data* is a sequence) with the probability that is proportional + to corresponded values. + + .. function:: shuffle(x[, random]) Shuffle the sequence *x* in place. The optional argument *random* is a @@ -290,8 +304,8 @@ .. _random-examples: -Examples and Recipes --------------------- +Examples +-------- Basic usage:: @@ -318,22 +332,9 @@ >>> random.sample([1, 2, 3, 4, 5], 3) # Three samples without replacement [4, 1, 5] -A common task is to make a :func:`random.choice` with weighted probabilities. + >>> data = {'Red': 3, 'Blue': 2, 'Yellow': 1, 'Green': 4} + >>> random.weighted_choice(data) + 'Green' -If the weights are small integer ratios, a simple technique is to build a sample -population with repeats:: - - >>> weighted_choices = [('Red', 3), ('Blue', 2), ('Yellow', 1), ('Green', 4)] - >>> population = [val for val, cnt in weighted_choices for i in range(cnt)] - >>> random.choice(population) - 'Green' - -A more general approach is to arrange the weights in a cumulative distribution -with :func:`itertools.accumulate`, and then locate the random value with -:func:`bisect.bisect`:: - - >>> choices, weights = zip(*weighted_choices) - >>> cumdist = list(itertools.accumulate(weights)) - >>> x = random.random() * cumdist[-1] - >>> choices[bisect.bisect(cumdist, x)] - 'Blue' + >>> list(itertools.islice(random.weighted_choice_generator(data), 5)) + ['Red', 'Blue', 'Red', 'Green', 'Green'] diff -r 06cf4044a11a Lib/random.py --- a/Lib/random.py Sat Aug 09 09:34:25 2014 +0300 +++ b/Lib/random.py Sun Aug 10 11:47:06 2014 +0300 @@ -41,15 +41,17 @@ from math import log as _log, exp as _exp, pi as _pi, e as _e, ceil as _ceil from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin from os import urandom as _urandom -from _collections_abc import Set as _Set, Sequence as _Sequence +from _collections_abc import Set as _Set, Sequence as _Sequence, Mapping as _Mapping from hashlib import sha512 as _sha512 +from itertools import accumulate as _accumulate +from bisect import bisect as _bisect __all__ = ["Random","seed","random","uniform","randint","choice","sample", "randrange","shuffle","normalvariate","lognormvariate", "expovariate","vonmisesvariate","gammavariate","triangular", "gauss","betavariate","paretovariate","weibullvariate", "getstate","setstate", "getrandbits", - "SystemRandom"] + "SystemRandom", "weighted_choice_generator", "weighted_choice"] NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0) TWOPI = 2.0*_pi @@ -255,6 +257,56 @@ raise IndexError('Cannot choose from an empty sequence') return seq[i] + def weighted_choice(self, data): + """Choose a random element with the chances defined by relative weights. + + If argument is a mapping then returns a random key with the probability + that is proportional to the value mapped from this key. + + If argument is a sequence then returns a random index with the + probability that is proportional to value at this index. + """ + return next(self.weighted_choice_generator(data)) + + def weighted_choice_generator(self, data): + """An iterator of random elements with the chances defined by relative + weights. + + If argument is a mapping then generates random keys with the + probability that is proportional to the value mapped from this key. + + If argument is a sequence then generates random indices with the + probability that is proportional to value at this index. + """ + if isinstance(data, _Mapping): + indices = list(data.keys()) + weights = data.values() + else: + indices = None + weights = data + cumulative_dist = list(_accumulate(weights)) + total_sum = cumulative_dist[-1] + if any(w < 0 for w in weights): + raise ValueError("All weights must be non-negative") + if not total_sum: + raise ValueError("At least one weight must be greater than zero") + del weights + + # Fast path for weighted_choice() + u = _bisect(cumulative_dist, self.random() * total_sum) + if indices is not None: + u = indices[u] + yield u + + k = 1.0 / total_sum + cumulative_dist = [k * s for s in cumulative_dist] + if indices is None: + for u in iter(self.random, None): + yield _bisect(cumulative_dist, u) + else: + for u in iter(self.random, None): + yield indices[_bisect(cumulative_dist, u)] + def shuffle(self, x, random=None): """Shuffle list x in place, and return None. @@ -722,6 +774,8 @@ triangular = _inst.triangular randint = _inst.randint choice = _inst.choice +weighted_choice_generator = _inst.weighted_choice_generator +weighted_choice = _inst.weighted_choice randrange = _inst.randrange sample = _inst.sample shuffle = _inst.shuffle diff -r 06cf4044a11a Lib/test/test_random.py --- a/Lib/test/test_random.py Sat Aug 09 09:34:25 2014 +0300 +++ b/Lib/test/test_random.py Sun Aug 10 11:47:06 2014 +0300 @@ -6,6 +6,7 @@ import warnings from functools import partial from math import log, exp, pi, fsum, sin +from itertools import islice from test import support class TestBasicOps: @@ -95,6 +96,60 @@ self.assertEqual(choice([50]), 50) self.assertIn(choice([25, 75]), [25, 75]) + def test_weighted_choice(self): + weighted_choice = self.gen.weighted_choice + with self.assertRaises(TypeError): + weighted_choice() + with self.assertRaises(TypeError): + weighted_choice(42) + with self.assertRaises(TypeError): + weighted_choice([1], [1]) + with self.assertRaises(IndexError): + weighted_choice([]) + with self.assertRaises(ValueError): + weighted_choice([0]) + with self.assertRaises(ValueError): + weighted_choice([2, -1]) + self.assertEqual(weighted_choice([50]), 0) + self.assertIn(weighted_choice([25, 75]), [0, 1]) + with self.assertRaises(IndexError): + weighted_choice({}) + with self.assertRaises(ValueError): + weighted_choice({'spam': 0}) + with self.assertRaises(ValueError): + weighted_choice({'spam': 2, 'ham': -1}) + self.assertEqual(weighted_choice({'spam': 50}), 'spam') + self.assertIn(weighted_choice({'spam': 25, 'ham': 75}), ['ham', 'spam']) + + def test_weighted_choice_generator(self): + wcg = self.gen.weighted_choice_generator + with self.assertRaises(TypeError): + wcg() + with self.assertRaises(TypeError): + next(wcg(42)) + with self.assertRaises(TypeError): + wcg([1], [1]) + with self.assertRaises(IndexError): + next(wcg([])) + with self.assertRaises(ValueError): + next(wcg([0])) + with self.assertRaises(ValueError): + next(wcg([2, -1])) + for i in islice(wcg([50]), 10): + self.assertEqual(i, 0) + for i in islice(wcg([25, 75]), 10): + self.assertIn(i, [0, 1]) + with self.assertRaises(IndexError): + next(wcg({})) + with self.assertRaises(ValueError): + next(wcg({'spam': 0})) + with self.assertRaises(ValueError): + next(wcg({'spam': 2, 'ham': -1})) + for i in islice(wcg({'spam': 50}), 10): + self.assertEqual(i, 'spam') + for i in islice(wcg({'spam': 25, 'ham': 75}), 10): + self.assertIn(i, ['ham', 'spam']) + def test_sample(self): # For the entire allowable range of 0 <= k <= N, validate that # the sample is of the correct length and contains only unique items