Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit e275188

Browse files
authored
Use Epochs to load CIFAR10 (#495)
1 parent b0cc1c4 commit e275188

File tree

10 files changed

+376
-136
lines changed

10 files changed

+376
-136
lines changed

Benchmarks/Models/ResNetCIFAR10.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ enum ResNetCIFAR10: BenchmarkModel {
3838
}
3939

4040
/// Builds the inference benchmark for this model.
///
/// - Parameter settings: the benchmark settings forwarded to the benchmark.
/// - Returns: an `ImageClassificationInference` specialized on `ResNet56` and `OldCIFAR10`.
static func makeInferenceBenchmark(settings: BenchmarkSettings) -> Benchmark {
  let benchmark = ImageClassificationInference<ResNet56, OldCIFAR10>(settings: settings)
  return benchmark
}
4343

4444
/// Builds the training benchmark for this model.
///
/// - Parameter settings: the benchmark settings forwarded to the benchmark.
/// - Returns: an `ImageClassificationTraining` specialized on `ResNet56` and `OldCIFAR10`.
static func makeTrainingBenchmark(settings: BenchmarkSettings) -> Benchmark {
  let benchmark = ImageClassificationTraining<ResNet56, OldCIFAR10>(settings: settings)
  return benchmark
}
4747
}
4848

Datasets/CIFAR10/CIFAR10.swift

Lines changed: 125 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -22,120 +22,142 @@ import ModelSupport
2222
import TensorFlow
2323
import Batcher
2424

25-
public struct CIFAR10: ImageClassificationDataset {
26-
public typealias SourceDataSet = [TensorPair<Float, Int32>]
27-
public let training: Batcher<SourceDataSet>
28-
public let test: Batcher<SourceDataSet>
29-
30-
public init(batchSize: Int) {
31-
self.init(
32-
batchSize: batchSize,
33-
remoteBinaryArchiveLocation: URL(
34-
string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!,
35-
normalizing: true)
25+
/// The CIFAR-10 dataset, exposed as lazily collated training epochs and validation batches.
///
/// `Entropy` is the type of the random number generator handed to `TrainingEpochs`.
public struct CIFAR10<Entropy: RandomNumberGenerator> {
  /// Type of the collection of non-collated batches.
  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>

  /// The type of the training data, represented as a sequence of epochs, each of
  /// which is a collection of batches.
  public typealias Training = LazyMapSequence<
    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
    LazyMapSequence<Batches, LabeledImage>
  >

  /// The type of the validation data, represented as a collection of batches.
  public typealias Validation = LazyMapSequence<
    Slices<[(data: [UInt8], label: Int32)]>, LabeledImage
  >

  /// The training epochs.
  public let training: Training

  /// The validation batches.
  public let validation: Validation

  /// Creates an instance with `batchSize`, reading the archive from its default
  /// hosted location and normalizing the batches.
  ///
  /// - Parameters:
  ///   - batchSize: the number of samples per batch.
  ///   - entropy: a source of randomness used to shuffle sample ordering. It
  ///     will be stored in `self`, so if it is only pseudorandom and has value
  ///     semantics, the sequence of epochs is deterministic and not dependent
  ///     on other operations.
  public init(batchSize: Int, entropy: Entropy) {
    // The literal is a fixed URL, so a failed parse is a programmer error.
    let hostedArchive = URL(
      string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!
    self.init(
      batchSize: batchSize,
      entropy: entropy,
      remoteBinaryArchiveLocation: hostedArchive,
      normalizing: true)
  }

  /// Creates an instance with `batchSize` using `remoteBinaryArchiveLocation`.
  ///
  /// - Parameters:
  ///   - batchSize: the number of samples per batch.
  ///   - entropy: a source of randomness used to shuffle sample ordering. It
  ///     will be stored in `self`, so if it is only pseudorandom and has value
  ///     semantics, the sequence of epochs is deterministic and not dependent
  ///     on other operations.
  ///   - remoteBinaryArchiveLocation: where to fetch the binary archive from when
  ///     the dataset is not already present locally.
  ///   - localStorageDirectory: where the dataset is stored locally. Defaults to
  ///     the `CIFAR10` subdirectory of `DatasetUtilities.defaultDirectory`.
  ///   - normalizing: normalizes the batches with the mean and standard deviation
  ///     of the dataset iff `true`.
  public init(
    batchSize: Int,
    entropy: Entropy,
    remoteBinaryArchiveLocation: URL,
    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
      .appendingPathComponent("CIFAR10", isDirectory: true),
    normalizing: Bool
  ) {
    downloadCIFAR10IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory)

    // Training data: each epoch is a collection of sample batches, and each
    // batch is mapped lazily through `makeBatch`.
    let epochs = TrainingEpochs(
      samples: loadCIFARTrainingFiles(in: localStorageDirectory),
      batchSize: batchSize,
      entropy: entropy)
    training = epochs.lazy.map { (epoch: Batches) -> LazyMapSequence<Batches, LabeledImage> in
      return epoch.lazy.map { makeBatch(samples: $0, normalizing: normalizing) }
    }

    // Validation data: fixed-order batches, mapped lazily in the same way.
    let heldOutSamples = loadCIFARTestFile(in: localStorageDirectory)
    validation = heldOutSamples.inBatches(of: batchSize).lazy.map {
      makeBatch(samples: $0, normalizing: normalizing)
    }
  }
}
3789

38-
public init(
39-
batchSize: Int,
40-
remoteBinaryArchiveLocation: URL,
41-
localStorageDirectory: URL = DatasetUtilities.defaultDirectory
42-
.appendingPathComponent("CIFAR10", isDirectory: true),
43-
normalizing: Bool)
44-
{
45-
downloadCIFAR10IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory)
46-
self.training = Batcher(
47-
on: loadCIFARTrainingFiles(localStorageDirectory: localStorageDirectory, normalizing: normalizing),
48-
batchSize: batchSize,
49-
numWorkers: 1, //No need to use parallelism since everything is loaded in memory
50-
shuffle: true)
51-
self.test = Batcher(
52-
on: loadCIFARTestFile(localStorageDirectory: localStorageDirectory, normalizing: normalizing),
53-
batchSize: batchSize,
54-
numWorkers: 1) //No need to use parallelism since everything is loaded in memory
55-
}
90+
extension CIFAR10: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
  /// Creates an instance with `batchSize`, using a new `SystemRandomNumberGenerator`
  /// as the source of entropy.
  public init(batchSize: Int) {
    let systemEntropy = SystemRandomNumberGenerator()
    self.init(batchSize: batchSize, entropy: systemEntropy)
  }
}
5796

5897
/// Requests the CIFAR-10 binary archive unless the dataset already appears to be present.
///
/// The dataset counts as present when `cifar-10-batches-bin` exists inside `directory`
/// and its contents can be listed and are non-empty; in that case nothing is done.
///
/// - Parameters:
///   - location: URL of the remote archive. Only its parent (the URL with the last
///     path component removed) is passed on as the remote root; the resource requested
///     is always `cifar-10-binary` with extension `tar.gz`.
///   - directory: local storage directory handed to `DatasetUtilities.downloadResource`.
func downloadCIFAR10IfNotPresent(from location: URL, to directory: URL) {
  let downloadPath = directory.appendingPathComponent("cifar-10-batches-bin").path
  let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

  // A failed listing (`nil`) is treated the same as an empty directory.
  let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
  let directoryEmpty = contentsOfDir?.isEmpty ?? true

  guard !directoryExists || directoryEmpty else { return }

  // The returned value is intentionally unused.
  _ = DatasetUtilities.downloadResource(
    filename: "cifar-10-binary", fileExtension: "tar.gz",
    remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
}
70109

71-
func loadCIFARFile(named name: String, in directory: URL, normalizing: Bool = true) -> [TensorPair<Float, Int32>] {
72-
let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path
73-
74-
let imageCount = 10000
75-
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
76-
printError("Could not read dataset file: \(name)")
77-
exit(-1)
78-
}
79-
guard fileContents.count == 30_730_000 else {
80-
printError(
81-
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
82-
exit(-1)
83-
}
84-
85-
var bytes: [UInt8] = []
86-
var labels: [Int64] = []
87-
88-
let imageByteSize = 3073
89-
for imageIndex in 0..<imageCount {
90-
let baseAddress = imageIndex * imageByteSize
91-
labels.append(Int64(fileContents[baseAddress]))
92-
bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
93-
}
94-
95-
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
96-
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)
97-
98-
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
99-
var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
100-
101-
// The value of mean and std were calculated with the following Swift code:
102-
// ```
103-
// import TensorFlow
104-
// import Datasets
105-
// import Foundation
106-
// let urlString = "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz"
107-
// let cifar = CIFAR10(batchSize: 50000,
108-
// remoteBinaryArchiveLocation: URL(string: urlString)!,
109-
// normalizing: false)
110-
// for batch in cifar.training.sequenced() {
111-
// let images = Tensor<Double>(batch.first) / 255.0
112-
// let mom = images.moments(squeezingAxes: [0,1,2])
113-
// print("mean: \(mom.mean) std: \(sqrt(mom.variance))")
114-
// }
115-
// ```
116-
if normalizing {
117-
let mean = Tensor<Float>(
118-
[0.4913996898,
119-
0.4821584196,
120-
0.4465309242])
121-
let std = Tensor<Float>(
122-
[0.2470322324,
123-
0.2434851280,
124-
0.2615878417])
125-
imageTensor = ((imageTensor / 255.0) - mean) / std
126-
}
127-
128-
return (0..<imageCount).map { TensorPair(first: imageTensor[$0], second: Tensor<Int32>(labelTensor[$0])) }
129-
110+
/// Reads one CIFAR-10 binary batch file into labeled raw image samples.
///
/// The file is read from `cifar-10-batches-bin/<name>` inside `directory` and must
/// contain exactly 10000 records of 3073 bytes each: one label byte followed by the
/// image bytes. If the file cannot be read or has a different size, an error is
/// reported through `printError` and the process exits with status `-1`.
///
/// - Parameters:
///   - name: file name of the batch, e.g. `test_batch.bin`.
///   - directory: directory that contains the `cifar-10-batches-bin` folder.
/// - Returns: one `(data, label)` pair per record, in file order; `data` holds the
///   bytes that follow the label byte.
func loadCIFARFile(named name: String, in directory: URL) -> [(data: [UInt8], label: Int32)] {
  let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path

  // Record layout, kept in one place so the size check and the slicing cannot drift apart.
  let imageCount = 10000
  let labelByteSize = 1
  let imageByteSize = 3073
  let expectedFileSize = imageCount * imageByteSize  // 30_730_000

  guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
    printError("Could not read dataset file: \(name)")
    exit(-1)
  }
  guard fileContents.count == expectedFileSize else {
    printError(
      "Dataset file \(name) should have \(expectedFileSize) bytes, instead had \(fileContents.count)")
    exit(-1)
  }

  return (0..<imageCount).map { imageIndex -> (data: [UInt8], label: Int32) in
    // Offset from `startIndex` so the subscripts stay valid for any `Data` value.
    let baseAddress = fileContents.startIndex + imageIndex * imageByteSize
    let label = Int32(fileContents[baseAddress])
    let data = [UInt8](fileContents[(baseAddress + labelByteSize)..<(baseAddress + imageByteSize)])
    return (data: data, label: label)
  }
}
131136

132-
func loadCIFARTrainingFiles(localStorageDirectory: URL, normalizing: Bool = true) -> [TensorPair<Float, Int32>] {
133-
let data = (1..<6).map {
134-
loadCIFARFile(named: "data_batch_\($0).bin", in: localStorageDirectory, normalizing: normalizing)
135-
}
136-
return data.reduce([], +)
137+
/// Loads the five training batch files, `data_batch_1.bin` through `data_batch_5.bin`,
/// and concatenates their samples in that order.
///
/// - Parameter localStorageDirectory: directory passed to `loadCIFARFile(named:in:)`.
/// - Returns: the samples of all five files as a single array.
func loadCIFARTrainingFiles(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] {
  var samples: [(data: [UInt8], label: Int32)] = []
  for batchNumber in 1...5 {
    samples += loadCIFARFile(named: "data_batch_\(batchNumber).bin", in: localStorageDirectory)
  }
  return samples
}
138143

139-
func loadCIFARTestFile(localStorageDirectory: URL, normalizing: Bool = true) -> [TensorPair<Float, Int32>] {
140-
return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory, normalizing: normalizing)
144+
/// Loads the samples of the test batch file, `test_batch.bin`.
///
/// - Parameter localStorageDirectory: directory passed to `loadCIFARFile(named:in:)`.
/// - Returns: the samples read from the test file.
func loadCIFARTestFile(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] {
  let testSamples = loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory)
  return testSamples
}
147+
148+
/// Collates raw samples into a single `LabeledImage` batch.
///
/// The image bytes of all samples are concatenated into a `[count, 3, 32, 32]` tensor,
/// transposed with permutation `[0, 2, 3, 1]`, converted to `Float` and divided by 255.
///
/// - Parameters:
///   - samples: the samples of one batch, each holding the image bytes and a label.
///   - normalizing: when `true`, the per-channel mean is subtracted and the result is
///     divided by the per-channel standard deviation after scaling.
/// - Returns: a `LabeledImage` with the image tensor as `data` and the labels as `label`.
func makeBatch<BatchSamples: Collection>(samples: BatchSamples, normalizing: Bool) -> LabeledImage
where BatchSamples.Element == (data: [UInt8], label: Int32) {
  // Gather the image bytes and the labels in a single pass over the batch.
  var pixelBytes: [UInt8] = []
  var labelValues: [Int32] = []
  for sample in samples {
    pixelBytes += sample.data
    labelValues.append(sample.label)
  }

  let channelsFirst = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: pixelBytes)

  // Move the channel axis last, then scale the values by 1/255.
  var imageTensor = Tensor<Float>(channelsFirst.transposed(permutation: [0, 2, 3, 1])) / 255.0

  if normalizing {
    let channelMean = Tensor<Float>([0.4913996898, 0.4821584196, 0.4465309242])
    let channelStd = Tensor<Float>([0.2470322324, 0.2434851280, 0.2615878417])
    imageTensor = (imageTensor - channelMean) / channelStd
  }

  return LabeledImage(data: imageTensor, label: Tensor<Int32>(labelValues))
}

0 commit comments

Comments
 (0)