Skip to content

Commit a90539f

Browse files
bpo-42944 Fix Random.sample when counts is not None (GH-24235) (GH-24243)
1 parent 799722c commit a90539f

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

Lib/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def sample(self, population, k, *, counts=None):
442442
raise TypeError('Counts must be integers')
443443
if total <= 0:
444444
raise ValueError('Total of counts must be greater than zero')
445-
selections = sample(range(total), k=k)
445+
selections = self.sample(range(total), k=k)
446446
bisect = _bisect
447447
return [population[bisect(cum_counts, s)] for s in selections]
448448
randbelow = self._randbelow

Lib/test/test_random.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -207,33 +207,6 @@ def test_sample_with_counts(self):
207207
with self.assertRaises(ValueError):
208208
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
209209

210-
def test_sample_counts_equivalence(self):
211-
# Test the documented strong equivalence to a sample with repeated elements.
212-
# We run this test on random.Random() which makes deterministic selections
213-
# for a given seed value.
214-
sample = random.sample
215-
seed = random.seed
216-
217-
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
218-
counts = [500, 200, 20, 10, 5, 1 ]
219-
k = 700
220-
seed(8675309)
221-
s1 = sample(colors, counts=counts, k=k)
222-
seed(8675309)
223-
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
224-
self.assertEqual(len(expanded), sum(counts))
225-
s2 = sample(expanded, k=k)
226-
self.assertEqual(s1, s2)
227-
228-
pop = 'abcdefghi'
229-
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
230-
seed(8675309)
231-
s1 = ''.join(sample(pop, counts=counts, k=30))
232-
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
233-
seed(8675309)
234-
s2 = ''.join(sample(expanded, k=30))
235-
self.assertEqual(s1, s2)
236-
237210
def test_choices(self):
238211
choices = self.gen.choices
239212
data = ['red', 'green', 'blue', 'yellow']
@@ -888,6 +861,33 @@ def test_randbytes_getrandbits(self):
888861
self.assertEqual(self.gen.randbytes(n),
889862
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
890863

864+
def test_sample_counts_equivalence(self):
865+
# Test the documented strong equivalence to a sample with repeated elements.
866+
# We run this test on random.Random() which makes deterministic selections
867+
# for a given seed value.
868+
sample = self.gen.sample
869+
seed = self.gen.seed
870+
871+
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
872+
counts = [500, 200, 20, 10, 5, 1 ]
873+
k = 700
874+
seed(8675309)
875+
s1 = sample(colors, counts=counts, k=k)
876+
seed(8675309)
877+
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
878+
self.assertEqual(len(expanded), sum(counts))
879+
s2 = sample(expanded, k=k)
880+
self.assertEqual(s1, s2)
881+
882+
pop = 'abcdefghi'
883+
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
884+
seed(8675309)
885+
s1 = ''.join(sample(pop, counts=counts, k=30))
886+
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
887+
seed(8675309)
888+
s2 = ''.join(sample(expanded, k=30))
889+
self.assertEqual(s1, s2)
890+
891891

892892
def gamma(z, sqrt2pi=(2.0*pi)**0.5):
893893
# Reflection to right half of complex plane
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ``random.Random.sample`` when ``counts`` argument is not ``None``.

0 commit comments

Comments
 (0)