@@ -22,120 +22,142 @@ import ModelSupport
22
22
import TensorFlow
23
23
import Batcher
24
24
25
- public struct CIFAR10 : ImageClassificationDataset {
26
- public typealias SourceDataSet = [ TensorPair < Float , Int32 > ]
27
- public let training : Batcher < SourceDataSet >
28
- public let test : Batcher < SourceDataSet >
29
-
30
- public init ( batchSize: Int ) {
31
- self . init (
32
- batchSize: batchSize,
33
- remoteBinaryArchiveLocation: URL (
34
- string: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz " ) !,
35
- normalizing: true )
25
+ public struct CIFAR10 < Entropy: RandomNumberGenerator > {
26
+ /// Type of the collection of non-collated batches.
27
+ public typealias Batches = Slices < Sampling < [ ( data: [ UInt8 ] , label: Int32 ) ] , ArraySlice < Int > > >
28
+ /// The type of the training data, represented as a sequence of epochs, which
29
+ /// are collection of batches.
30
+ public typealias Training = LazyMapSequence <
31
+ TrainingEpochs < [ ( data: [ UInt8 ] , label: Int32 ) ] , Entropy > ,
32
+ LazyMapSequence < Batches , LabeledImage >
33
+ >
34
+ /// The type of the validation data, represented as a collection of batches.
35
+ public typealias Validation = LazyMapSequence < Slices < [ ( data: [ UInt8 ] , label: Int32 ) ] > , LabeledImage >
36
+ /// The training epochs.
37
+ public let training : Training
38
+ /// The validation batches.
39
+ public let validation : Validation
40
+
41
+ /// Creates an instance with `batchSize`.
42
+ ///
43
+ /// - Parameter entropy: a source of randomness used to shuffle sample
44
+ /// ordering. It will be stored in `self`, so if it is only pseudorandom
45
+ /// and has value semantics, the sequence of epochs is deterministic and not
46
+ /// dependent on other operations.
47
+ public init ( batchSize: Int , entropy: Entropy ) {
48
+ self . init (
49
+ batchSize: batchSize,
50
+ entropy: entropy,
51
+ remoteBinaryArchiveLocation: URL (
52
+ string: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz " ) !,
53
+ normalizing: true )
54
+ }
55
+
56
+ /// Creates an instance with `batchSize` using `remoteBinaryArchiveLocation`.
57
+ ///
58
+ /// - Parameters:
59
+ /// - entropy: a source of randomness used to shuffle sample ordering. It
60
+ /// will be stored in `self`, so if it is only pseudorandom and has value
61
+ /// semantics, the sequence of epochs is deterministic and not dependent
62
+ /// on other operations.
63
+ /// - normalizing: normalizes the batches with the mean and standard deviation
64
+ /// of the dataset iff `true`. Default value is `true`.
65
+ public init (
66
+ batchSize: Int ,
67
+ entropy: Entropy ,
68
+ remoteBinaryArchiveLocation: URL ,
69
+ localStorageDirectory: URL = DatasetUtilities . defaultDirectory
70
+ . appendingPathComponent ( " CIFAR10 " , isDirectory: true ) ,
71
+ normalizing: Bool
72
+ ) {
73
+ downloadCIFAR10IfNotPresent ( from: remoteBinaryArchiveLocation, to: localStorageDirectory)
74
+
75
+ // Training data
76
+ let trainingSamples = loadCIFARTrainingFiles ( in: localStorageDirectory)
77
+ training = TrainingEpochs ( samples: trainingSamples, batchSize: batchSize, entropy: entropy)
78
+ . lazy. map { ( batches: Batches ) -> LazyMapSequence < Batches , LabeledImage > in
79
+ return batches. lazy. map { makeBatch ( samples: $0, normalizing: normalizing) }
80
+ }
81
+
82
+ // Validation data
83
+ let validationSamples = loadCIFARTestFile ( in: localStorageDirectory)
84
+ validation = validationSamples. inBatches ( of: batchSize) . lazy. map {
85
+ makeBatch ( samples: $0, normalizing: normalizing)
36
86
}
87
+ }
88
+ }
37
89
38
- public init (
39
- batchSize: Int ,
40
- remoteBinaryArchiveLocation: URL ,
41
- localStorageDirectory: URL = DatasetUtilities . defaultDirectory
42
- . appendingPathComponent ( " CIFAR10 " , isDirectory: true ) ,
43
- normalizing: Bool )
44
- {
45
- downloadCIFAR10IfNotPresent ( from: remoteBinaryArchiveLocation, to: localStorageDirectory)
46
- self . training = Batcher (
47
- on: loadCIFARTrainingFiles ( localStorageDirectory: localStorageDirectory, normalizing: normalizing) ,
48
- batchSize: batchSize,
49
- numWorkers: 1 , //No need to use parallelism since everything is loaded in memory
50
- shuffle: true )
51
- self . test = Batcher (
52
- on: loadCIFARTestFile ( localStorageDirectory: localStorageDirectory, normalizing: normalizing) ,
53
- batchSize: batchSize,
54
- numWorkers: 1 ) //No need to use parallelism since everything is loaded in memory
55
- }
90
+ extension CIFAR10 : ImageClassificationData where Entropy == SystemRandomNumberGenerator {
91
+ /// Creates an instance with `batchSize`.
92
+ public init ( batchSize: Int ) {
93
+ self . init ( batchSize: batchSize, entropy: SystemRandomNumberGenerator ( ) )
94
+ }
56
95
}
57
96
58
97
func downloadCIFAR10IfNotPresent( from location: URL , to directory: URL ) {
59
- let downloadPath = directory. appendingPathComponent ( " cifar-10-batches-bin " ) . path
60
- let directoryExists = FileManager . default. fileExists ( atPath: downloadPath)
61
- let contentsOfDir = try ? FileManager . default. contentsOfDirectory ( atPath: downloadPath)
62
- let directoryEmpty = ( contentsOfDir == nil ) || ( contentsOfDir!. isEmpty)
98
+ let downloadPath = directory. appendingPathComponent ( " cifar-10-batches-bin " ) . path
99
+ let directoryExists = FileManager . default. fileExists ( atPath: downloadPath)
100
+ let contentsOfDir = try ? FileManager . default. contentsOfDirectory ( atPath: downloadPath)
101
+ let directoryEmpty = ( contentsOfDir == nil ) || ( contentsOfDir!. isEmpty)
63
102
64
- guard !directoryExists || directoryEmpty else { return }
103
+ guard !directoryExists || directoryEmpty else { return }
65
104
66
- let _ = DatasetUtilities . downloadResource (
67
- filename: " cifar-10-binary " , fileExtension: " tar.gz " ,
68
- remoteRoot: location. deletingLastPathComponent ( ) , localStorageDirectory: directory)
105
+ let _ = DatasetUtilities . downloadResource (
106
+ filename: " cifar-10-binary " , fileExtension: " tar.gz " ,
107
+ remoteRoot: location. deletingLastPathComponent ( ) , localStorageDirectory: directory)
69
108
}
70
109
71
- func loadCIFARFile( named name: String , in directory: URL , normalizing: Bool = true ) -> [ TensorPair < Float , Int32 > ] {
72
- let path = directory. appendingPathComponent ( " cifar-10-batches-bin/ \( name) " ) . path
73
-
74
- let imageCount = 10000
75
- guard let fileContents = try ? Data ( contentsOf: URL ( fileURLWithPath: path) ) else {
76
- printError ( " Could not read dataset file: \( name) " )
77
- exit ( - 1 )
78
- }
79
- guard fileContents. count == 30_730_000 else {
80
- printError (
81
- " Dataset file \( name) should have 30730000 bytes, instead had \( fileContents. count) " )
82
- exit ( - 1 )
83
- }
84
-
85
- var bytes : [ UInt8 ] = [ ]
86
- var labels : [ Int64 ] = [ ]
87
-
88
- let imageByteSize = 3073
89
- for imageIndex in 0 ..< imageCount {
90
- let baseAddress = imageIndex * imageByteSize
91
- labels. append ( Int64 ( fileContents [ baseAddress] ) )
92
- bytes. append ( contentsOf: fileContents [ ( baseAddress + 1 ) ..< ( baseAddress + 3073 ) ] )
93
- }
94
-
95
- let labelTensor = Tensor < Int64 > ( shape: [ imageCount] , scalars: labels)
96
- let images = Tensor < UInt8 > ( shape: [ imageCount, 3 , 32 , 32 ] , scalars: bytes)
97
-
98
- // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
99
- var imageTensor = Tensor < Float > ( images. transposed ( permutation: [ 0 , 2 , 3 , 1 ] ) )
100
-
101
- // The value of mean and std were calculated with the following Swift code:
102
- // ```
103
- // import TensorFlow
104
- // import Datasets
105
- // import Foundation
106
- // let urlString = "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz"
107
- // let cifar = CIFAR10(batchSize: 50000,
108
- // remoteBinaryArchiveLocation: URL(string: urlString)!,
109
- // normalizing: false)
110
- // for batch in cifar.training.sequenced() {
111
- // let images = Tensor<Double>(batch.first) / 255.0
112
- // let mom = images.moments(squeezingAxes: [0,1,2])
113
- // print("mean: \(mom.mean) std: \(sqrt(mom.variance))")
114
- // }
115
- // ```
116
- if normalizing {
117
- let mean = Tensor < Float > (
118
- [ 0.4913996898 ,
119
- 0.4821584196 ,
120
- 0.4465309242 ] )
121
- let std = Tensor < Float > (
122
- [ 0.2470322324 ,
123
- 0.2434851280 ,
124
- 0.2615878417 ] )
125
- imageTensor = ( ( imageTensor / 255.0 ) - mean) / std
126
- }
127
-
128
- return ( 0 ..< imageCount) . map { TensorPair ( first: imageTensor [ $0] , second: Tensor < Int32 > ( labelTensor [ $0] ) ) }
129
-
110
+ func loadCIFARFile( named name: String , in directory: URL ) -> [ ( data: [ UInt8 ] , label: Int32 ) ] {
111
+ let path = directory. appendingPathComponent ( " cifar-10-batches-bin/ \( name) " ) . path
112
+
113
+ let imageCount = 10000
114
+ guard let fileContents = try ? Data ( contentsOf: URL ( fileURLWithPath: path) ) else {
115
+ printError ( " Could not read dataset file: \( name) " )
116
+ exit ( - 1 )
117
+ }
118
+ guard fileContents. count == 30_730_000 else {
119
+ printError (
120
+ " Dataset file \( name) should have 30730000 bytes, instead had \( fileContents. count) " )
121
+ exit ( - 1 )
122
+ }
123
+
124
+ var labeledImages : [ ( data: [ UInt8 ] , label: Int32 ) ] = [ ]
125
+
126
+ let imageByteSize = 3073
127
+ for imageIndex in 0 ..< imageCount {
128
+ let baseAddress = imageIndex * imageByteSize
129
+ let label = Int32 ( fileContents [ baseAddress] )
130
+ let data = [ UInt8] ( fileContents [ ( baseAddress + 1 ) ..< ( baseAddress + 3073 ) ] )
131
+ labeledImages. append ( ( data: data, label: label) )
132
+ }
133
+
134
+ return labeledImages
130
135
}
131
136
132
- func loadCIFARTrainingFiles( localStorageDirectory: URL , normalizing : Bool = true ) -> [ TensorPair < Float , Int32 > ] {
133
- let data = ( 1 ..< 6 ) . map {
134
- loadCIFARFile ( named: " data_batch_ \( $0) .bin " , in: localStorageDirectory, normalizing : normalizing )
135
- }
136
- return data. reduce ( [ ] , + )
137
+ func loadCIFARTrainingFiles( in localStorageDirectory: URL ) -> [ ( data : [ UInt8 ] , label : Int32 ) ] {
138
+ let data = ( 1 ..< 6 ) . map {
139
+ loadCIFARFile ( named: " data_batch_ \( $0) .bin " , in: localStorageDirectory)
140
+ }
141
+ return data. reduce ( [ ] , + )
137
142
}
138
143
139
- func loadCIFARTestFile( localStorageDirectory: URL , normalizing : Bool = true ) -> [ TensorPair < Float , Int32 > ] {
140
- return loadCIFARFile ( named: " test_batch.bin " , in: localStorageDirectory, normalizing : normalizing )
144
+ func loadCIFARTestFile( in localStorageDirectory: URL ) -> [ ( data : [ UInt8 ] , label : Int32 ) ] {
145
+ return loadCIFARFile ( named: " test_batch.bin " , in: localStorageDirectory)
141
146
}
147
+
148
+ func makeBatch< BatchSamples: Collection > ( samples: BatchSamples , normalizing: Bool ) -> LabeledImage
149
+ where BatchSamples. Element == ( data: [ UInt8 ] , label: Int32 ) {
150
+ let bytes = samples. lazy. map ( \. data) . reduce ( into: [ ] , += )
151
+ let images = Tensor < UInt8 > ( shape: [ samples. count, 3 , 32 , 32 ] , scalars: bytes)
152
+
153
+ var imageTensor = Tensor < Float > ( images. transposed ( permutation: [ 0 , 2 , 3 , 1 ] ) )
154
+ imageTensor /= 255.0
155
+ if normalizing {
156
+ let mean = Tensor < Float > ( [ 0.4913996898 , 0.4821584196 , 0.4465309242 ] )
157
+ let std = Tensor < Float > ( [ 0.2470322324 , 0.2434851280 , 0.2615878417 ] )
158
+ imageTensor = ( imageTensor - mean) / std
159
+ }
160
+
161
+ let labels = Tensor < Int32 > ( samples. map ( \. label) )
162
+ return LabeledImage ( data: imageTensor, label: labels)
163
+ }
0 commit comments