@@ -90,25 +90,25 @@ final class EpochsTests: XCTestCase {
90
90
// `inBatches` is lazy so no elements were accessed.
91
91
XCTAssert (
92
92
dataset. accessed. allSatisfy { !$0 } ,
93
- " No elements should have been accessed yet." )
93
+ " Laziness failure: no elements should have been accessed yet." )
94
94
for (i, batch) in batches. enumerated ( ) {
95
95
// Elements are not accessed until we do something with `batch` so only
96
96
// the elements up to `i * batchSize` have been accessed yet.
97
97
XCTAssert (
98
98
dataset. accessed [ ..< ( i * batchSize) ] . allSatisfy { $0 } ,
99
- " Not all elements prior to \( i * batchSize ) have been accessed ." )
99
+ " Some samples in a prior batch were unexpectedly skipped ." )
100
100
XCTAssert (
101
101
dataset. accessed [ ( i * batchSize) ... ] . allSatisfy { !$0 } ,
102
- " Some elements after \( i * batchSize ) have been accessed ." )
102
+ " Laziness failure: some samples were read prematurely ." )
103
103
let _ = Array ( batch)
104
104
let limit = ( i + 1 ) * batchSize
105
105
// We accessed elements up to `limit` but no further.
106
106
XCTAssert (
107
107
dataset. accessed [ ..< limit] . allSatisfy { $0 } ,
108
- " Not all elements prior to \( limit ) have been accessed ." )
108
+ " Some samples in a prior batch were unexpectedly skipped ." )
109
109
XCTAssert (
110
110
dataset. accessed [ limit... ] . allSatisfy { !$0 } ,
111
- " Some elements after \( limit ) have been accessed ." )
111
+ " Laziness failure: some samples were read prematurely ." )
112
112
}
113
113
}
114
114
@@ -119,7 +119,7 @@ final class EpochsTests: XCTestCase {
119
119
samples: dataset, batchSize: batchSize,
120
120
entropy: rng
121
121
) . prefix ( 10 )
122
- var lastEpochSampleOrder = Array ( 0 ..< 512 )
122
+ var lastEpochSampleOrder : [ Int ] ? = nil
123
123
for batches in epochs {
124
124
var newEpochSampleOrder : [ Int ] = [ ]
125
125
for batch in batches {
@@ -131,39 +131,41 @@ final class EpochsTests: XCTestCase {
131
131
132
132
newEpochSampleOrder += samples
133
133
}
134
- XCTAssertNotEqual (
135
- lastEpochSampleOrder, newEpochSampleOrder,
136
- " Dataset should have been reshuffled. " )
134
+ if let l = lastEpochSampleOrder {
135
+ XCTAssertNotEqual (
136
+ l, newEpochSampleOrder,
137
+ " Dataset should have been reshuffled. " )
138
+ }
137
139
138
- lastEpochSampleOrder = newEpochSampleOrder
139
- let uniqueSamples = Set ( lastEpochSampleOrder)
140
+ let uniqueSamples = Set ( newEpochSampleOrder)
140
141
XCTAssertEqual (
141
- uniqueSamples. count, lastEpochSampleOrder . count,
142
+ uniqueSamples. count, newEpochSampleOrder . count,
142
143
" Every epoch sample should be drawn from a different input sample. " )
144
+ lastEpochSampleOrder = newEpochSampleOrder
143
145
}
144
146
}
145
147
146
- func testTrainingEpochsDropsRemainder ( ) {
148
+ func testTrainingEpochsShapes ( ) {
147
149
let batchSize = 64
148
- let dataset = Array ( 0 ..< 500 )
150
+ let dataset = 0 ..< 500
149
151
let epochs = TrainingEpochs (
150
152
samples: dataset, batchSize: batchSize,
151
153
entropy: rng
152
154
) . prefix ( 1 )
153
- let samplesCount = dataset. count - dataset. count % 64
154
- for batches in epochs {
155
- XCTAssertEqual ( batches. count, 7 , " Incorrect number of batches. " )
156
- var count = 0
157
- for batch in batches {
158
- let samples = Array ( batch)
155
+
156
+ for epochBatches in epochs {
157
+ XCTAssertEqual ( epochBatches. count, 7 , " Incorrect number of batches. " )
158
+ var epochSampleCount = 0
159
+ for batch in epochBatches {
159
160
XCTAssertEqual (
160
- samples. count, batchSize,
161
- " This batch doesn't have batchSize elements. " )
162
- count += samples. count
161
+ batch. count, batchSize, " unexpected batch size: \( batch. count) " )
162
+ epochSampleCount += batch. count
163
163
}
164
+ let expectedDropCount = dataset. count % 64
165
+ let actualDropCount = dataset. count - epochSampleCount
164
166
XCTAssertEqual (
165
- count , samplesCount ,
166
- " Didn't access the right number of samples ." )
167
+ expectedDropCount , actualDropCount ,
168
+ " Dropped \( actualDropCount ) samples but expected \( expectedDropCount ) . " )
167
169
}
168
170
}
169
171
@@ -399,7 +401,7 @@ extension EpochsTests {
399
401
( " testInBatchesIsLazy " , testInBatchesIsLazy) ,
400
402
( " testBaseUse " , testBaseUse) ,
401
403
( " testTrainingEpochsShuffles " , testTrainingEpochsShuffles) ,
402
- ( " testTrainingEpochsDropsRemainder " , testTrainingEpochsDropsRemainder ) ,
404
+ ( " testTrainingEpochsShapes " , testTrainingEpochsShapes ) ,
403
405
( " testTrainingEpochsIsLazy " , testTrainingEpochsIsLazy) ,
404
406
( " testLanguageModel " , testLanguageModel) ,
405
407
( " testLanguageModelShuffled " , testLanguageModelShuffled) ,
0 commit comments