Skip to content

Commit 1a3c706

Browse files
authored
Fix RandomSample crash when 0 was out from random. (#70)
1 parent d8a11e0 commit 1a3c706

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

Sources/Algorithms/RandomSample.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ internal func nextW<G: RandomNumberGenerator>(
9494
internal func nextOffset<G: RandomNumberGenerator>(
9595
w: Double, using rng: inout G
9696
) -> Int {
97-
Int(Double.log(.random(in: 0..<1, using: &rng)) / .log(1 - w))
97+
let offset = Double.log(.random(in: 0..<1, using: &rng)) / .log(onePlus: -w)
98+
return offset < Double(Int.max) ? Int(offset) : Int.max
9899
}
99100

100101
extension Collection {
@@ -201,10 +202,10 @@ extension Sequence {
201202
w *= nextW(k: k, using: &rng)
202203

203204
// Find the offset of the next element to swap into the reservoir.
204-
var offset = nextOffset(w: w, using: &rng) + 1
205+
var offset = nextOffset(w: w, using: &rng)
205206

206-
// Skip over `offset - 1` elements to find the selected element.
207-
while offset > 1, let _ = iterator.next() {
207+
// Skip over `offset` elements to find the selected element.
208+
while offset > 0, let _ = iterator.next() {
208209
offset -= 1
209210
}
210211
guard let nextElement = iterator.next() else { break }

Tests/SwiftAlgorithmsTests/RandomSampleTests.swift

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
//===----------------------------------------------------------------------===//
1111

1212
import XCTest
13-
import Algorithms
13+
@testable import Algorithms
1414

1515
func validateRandomSamples<S: Sequence>(
1616
_ samples: [Int: Int],
@@ -95,4 +95,32 @@ final class RandomSampleTests: XCTestCase {
9595
let sample2c = c.randomStableSample(count: k, using: &generator)
9696
XCTAssertEqual(sample1c, sample2c)
9797
}
98+
99+
func testRandomSampleRandomEdgeCasesInternal() {
100+
struct ZeroGenerator: RandomNumberGenerator {
101+
mutating func next() -> UInt64 { 0 }
102+
}
103+
var zero = ZeroGenerator()
104+
_ = nextOffset(w: 1, using: &zero) // must not crash
105+
106+
struct AlmostAllZeroGenerator: RandomNumberGenerator {
107+
private var forward: SplitMix64
108+
private var count: Int = 0
109+
110+
init(seed: UInt64) {
111+
forward = SplitMix64(seed: seed)
112+
}
113+
114+
mutating func next() -> UInt64 {
115+
defer { count &+= 1 }
116+
if count % 1000 == 0 { return forward.next() }
117+
return 0
118+
}
119+
}
120+
121+
var almostAllZero = AlmostAllZeroGenerator(seed: 0)
122+
_ = s.randomSample(count: k, using: &almostAllZero) // must not crash
123+
almostAllZero = AlmostAllZeroGenerator(seed: 0)
124+
_ = c.randomSample(count: k, using: &almostAllZero) // must not crash
125+
}
98126
}

0 commit comments

Comments
 (0)