Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6094803

Browse files
authored
Clean-up tests (#927)
1 parent 5bc6e87 commit 6094803

File tree

1 file changed

+28
-26
lines changed

1 file changed

+28
-26
lines changed

Tests/TensorFlowTests/EpochsTests.swift

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,25 @@ final class EpochsTests: XCTestCase {
9090
// `inBatches` is lazy so no elements were accessed.
9191
XCTAssert(
9292
dataset.accessed.allSatisfy { !$0 },
93-
"No elements should have been accessed yet.")
93+
"Laziness failure: no elements should have been accessed yet.")
9494
for (i, batch) in batches.enumerated() {
9595
// Elements are not accessed until we do something with `batch` so only
9696
// the elements up to `i * batchSize` have been accessed yet.
9797
XCTAssert(
9898
dataset.accessed[..<(i * batchSize)].allSatisfy { $0 },
99-
"Not all elements prior to \(i * batchSize) have been accessed.")
99+
"Some samples in a prior batch were unexpectedly skipped.")
100100
XCTAssert(
101101
dataset.accessed[(i * batchSize)...].allSatisfy { !$0 },
102-
"Some elements after \(i * batchSize) have been accessed.")
102+
"Laziness failure: some samples were read prematurely.")
103103
let _ = Array(batch)
104104
let limit = (i + 1) * batchSize
105105
// We accessed elements up to `limit` but no further.
106106
XCTAssert(
107107
dataset.accessed[..<limit].allSatisfy { $0 },
108-
"Not all elements prior to \(limit) have been accessed.")
108+
"Some samples in a prior batch were unexpectedly skipped.")
109109
XCTAssert(
110110
dataset.accessed[limit...].allSatisfy { !$0 },
111-
"Some elements after \(limit) have been accessed.")
111+
"Laziness failure: some samples were read prematurely.")
112112
}
113113
}
114114

@@ -119,7 +119,7 @@ final class EpochsTests: XCTestCase {
119119
samples: dataset, batchSize: batchSize,
120120
entropy: rng
121121
).prefix(10)
122-
var lastEpochSampleOrder = Array(0..<512)
122+
var lastEpochSampleOrder: [Int]? = nil
123123
for batches in epochs {
124124
var newEpochSampleOrder: [Int] = []
125125
for batch in batches {
@@ -131,39 +131,41 @@ final class EpochsTests: XCTestCase {
131131

132132
newEpochSampleOrder += samples
133133
}
134-
XCTAssertNotEqual(
135-
lastEpochSampleOrder, newEpochSampleOrder,
136-
"Dataset should have been reshuffled.")
134+
if let l = lastEpochSampleOrder {
135+
XCTAssertNotEqual(
136+
l, newEpochSampleOrder,
137+
"Dataset should have been reshuffled.")
138+
}
137139

138-
lastEpochSampleOrder = newEpochSampleOrder
139-
let uniqueSamples = Set(lastEpochSampleOrder)
140+
let uniqueSamples = Set(newEpochSampleOrder)
140141
XCTAssertEqual(
141-
uniqueSamples.count, lastEpochSampleOrder.count,
142+
uniqueSamples.count, newEpochSampleOrder.count,
142143
"Every epoch sample should be drawn from a different input sample.")
144+
lastEpochSampleOrder = newEpochSampleOrder
143145
}
144146
}
145147

146-
func testTrainingEpochsDropsRemainder() {
148+
func testTrainingEpochsShapes() {
147149
let batchSize = 64
148-
let dataset = Array(0..<500)
150+
let dataset = 0..<500
149151
let epochs = TrainingEpochs(
150152
samples: dataset, batchSize: batchSize,
151153
entropy: rng
152154
).prefix(1)
153-
let samplesCount = dataset.count - dataset.count % 64
154-
for batches in epochs {
155-
XCTAssertEqual(batches.count, 7, "Incorrect number of batches.")
156-
var count = 0
157-
for batch in batches {
158-
let samples = Array(batch)
155+
156+
for epochBatches in epochs {
157+
XCTAssertEqual(epochBatches.count, 7, "Incorrect number of batches.")
158+
var epochSampleCount = 0
159+
for batch in epochBatches {
159160
XCTAssertEqual(
160-
samples.count, batchSize,
161-
"This batch doesn't have batchSize elements.")
162-
count += samples.count
161+
batch.count, batchSize, "unexpected batch size: \(batch.count)")
162+
epochSampleCount += batch.count
163163
}
164+
let expectedDropCount = dataset.count % 64
165+
let actualDropCount = dataset.count - epochSampleCount
164166
XCTAssertEqual(
165-
count, samplesCount,
166-
"Didn't access the right number of samples.")
167+
expectedDropCount, actualDropCount,
168+
"Dropped \(actualDropCount) samples but expected \(expectedDropCount).")
167169
}
168170
}
169171

@@ -399,7 +401,7 @@ extension EpochsTests {
399401
("testInBatchesIsLazy", testInBatchesIsLazy),
400402
("testBaseUse", testBaseUse),
401403
("testTrainingEpochsShuffles", testTrainingEpochsShuffles),
402-
("testTrainingEpochsDropsRemainder", testTrainingEpochsDropsRemainder),
404+
("testTrainingEpochsShapes", testTrainingEpochsShapes),
403405
("testTrainingEpochsIsLazy", testTrainingEpochsIsLazy),
404406
("testLanguageModel", testLanguageModel),
405407
("testLanguageModelShuffled", testLanguageModelShuffled),

0 commit comments

Comments
 (0)