diff --git a/Lib/random.py b/Lib/random.py index 38c4a54052..8322144e8d 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -484,13 +484,11 @@ def choices(self, population, weights=None, *, cum_weights=None, k=1): the selections are made with equal probability. """ - random = self.random n = len(population) if cum_weights is None: if weights is None: - floor = _floor - n += 0.0 # convert to float for a small speed improvement - return [population[floor(random() * n)] for i in _repeat(None, k)] + randbelow = self._randbelow + return [population[randbelow(n)] for i in _repeat(None, k)] try: cum_weights = list(_accumulate(weights)) except TypeError: @@ -504,13 +502,19 @@ def choices(self, population, weights=None, *, cum_weights=None, k=1): raise TypeError('Cannot specify both weights and cumulative weights') if len(cum_weights) != n: raise ValueError('The number of weights does not match the population') - total = cum_weights[-1] + 0.0 # convert to float + total = cum_weights[-1] if total <= 0.0: raise ValueError('Total of weights must be greater than zero') if not _isfinite(total): raise ValueError('Total of weights must be finite') bisect = _bisect hi = n - 1 + if isinstance(total, int): + randbelow = self._randbelow + return [population[bisect(cum_weights, randbelow(total), 0, hi)] + for i in _repeat(None, k)] + total += 0.0 # convert to float + random = self.random return [population[bisect(cum_weights, random() * total, 0, hi)] for i in _repeat(None, k)] diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index c2dd50b981..3f24e9a7ed 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -303,6 +303,17 @@ def test_choices_subnormal(self): choices = self.gen.choices choices(population=[1, 2], weights=[1e-323, 1e-323], k=5000) + def test_choices_no_loss_of_precision(self): + # Make sure big odd numbers are included + K = 100 + for n in range(1, 64): + try: + choices = random.choices(range(2**n), k=K) + except OverflowError: + continue + parities = [x % 2 for x in choices] + self.assertNotEqual(parities, [0]*K, msg=f"{n} bit choices") + def test_choices_with_all_zero_weights(self): # See issue #38881 with self.assertRaises(ValueError):