Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit b353333

Browse files
authored
Continuing repository reorganization: extracting common CIFAR-10 and ResNet code (#185)
* Extracted CIFAR-10 dataset and ResNet models into respective modules. * Minor formatting fixes. * Mirroring PR #187.
1 parent d34fc3c commit b353333

File tree

18 files changed

+346
-656
lines changed

18 files changed

+346
-656
lines changed

CIFAR/Helpers.swift

Lines changed: 0 additions & 51 deletions
This file was deleted.

CIFAR/README.md

Lines changed: 0 additions & 18 deletions
This file was deleted.

CIFAR/ResNet.swift

Lines changed: 0 additions & 119 deletions
This file was deleted.

CIFAR/Data.swift renamed to Datasets/CIFAR10/CIFAR10.swift

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,28 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
// Original source:
16+
// "The CIFAR-10 dataset"
17+
// Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
18+
// https://www.cs.toronto.edu/~kriz/cifar.html
19+
1520
import Foundation
1621
import TensorFlow
1722

1823
#if canImport(FoundationNetworking)
1924
import FoundationNetworking
2025
#endif
2126

27+
public struct CIFAR10 {
28+
public let trainingDataset: Dataset<CIFARExample>
29+
public let testDataset: Dataset<CIFARExample>
30+
31+
public init() {
32+
self.trainingDataset = Dataset<CIFARExample>(elements: loadCIFARTrainingFiles())
33+
self.testDataset = Dataset<CIFARExample>(elements: loadCIFARTestFile())
34+
}
35+
}
36+
2237
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
2338
let downloadPath = "\(directory)/cifar-10-batches-bin"
2439
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
@@ -69,27 +84,7 @@ func downloadCIFAR10IfNotPresent(to directory: String = ".") {
6984
print("Unarchiving completed")
7085
}
7186

72-
struct Example: TensorGroup {
73-
var label: Tensor<Int32>
74-
var data: Tensor<Float>
75-
76-
init(label: Tensor<Int32>, data: Tensor<Float>) {
77-
self.label = label
78-
self.data = data
79-
}
80-
81-
public init<C: RandomAccessCollection>(
82-
_handles: C
83-
) where C.Element: _AnyTensorHandle {
84-
precondition(_handles.count == 2)
85-
let labelIndex = _handles.startIndex
86-
let dataIndex = _handles.index(labelIndex, offsetBy: 1)
87-
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
88-
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
89-
}
90-
}
91-
92-
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
87+
func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExample {
9388
downloadCIFAR10IfNotPresent(to: directory)
9489
let path = "\(directory)/cifar-10-batches-bin/\(name)"
9590

@@ -124,25 +119,17 @@ func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
124119
let std = Tensor<Float>([0.229, 0.224, 0.225])
125120
let imagesNormalized = ((imageTensor / 255.0) - mean) / std
126121

127-
return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
122+
return CIFARExample(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
128123
}
129124

130-
func loadCIFARTrainingFiles() -> Example {
125+
func loadCIFARTrainingFiles() -> CIFARExample {
131126
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
132-
return Example(
127+
return CIFARExample(
133128
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
134129
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
135130
)
136131
}
137132

138-
func loadCIFARTestFile() -> Example {
133+
func loadCIFARTestFile() -> CIFARExample {
139134
return loadCIFARFile(named: "test_batch.bin")
140135
}
141-
142-
func loadCIFAR10() -> (
143-
training: Dataset<Example>, test: Dataset<Example>
144-
) {
145-
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
146-
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
147-
return (training: trainingDataset, test: testDataset)
148-
}

Datasets/CIFAR10/CIFARExample.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
public struct CIFARExample: TensorGroup {
18+
public var label: Tensor<Int32>
19+
public var data: Tensor<Float>
20+
21+
public init(label: Tensor<Int32>, data: Tensor<Float>) {
22+
self.label = label
23+
self.data = data
24+
}
25+
26+
public init<C: RandomAccessCollection>(
27+
_handles: C
28+
) where C.Element: _AnyTensorHandle {
29+
precondition(_handles.count == 2)
30+
let labelIndex = _handles.startIndex
31+
let dataIndex = _handles.index(labelIndex, offsetBy: 1)
32+
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
33+
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
34+
}
35+
}

Datasets/MNIST/MNIST.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
// Original source:
16+
// "The MNIST database of handwritten digits"
17+
// Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
18+
// http://yann.lecun.com/exdb/mnist/
19+
1520
import Foundation
1621
import TensorFlow
1722

File renamed without changes.

Examples/Custom-CIFAR10/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# CIFAR-10 with custom models
2+
3+
This example demonstrates how to train the custom-defined models (based on examples from [PyTorch](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) and [Keras](https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py) ) against the [CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
4+
5+
Two custom models are defined, and one is applied to an instance of the CIFAR-10 dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.
6+
7+
## Setup
8+
9+
To begin, you'll need the [latest version of Swift for
10+
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
11+
installed. Make sure you've added the correct version of `swift` to your path.
12+
13+
To train the model, run:
14+
15+
```sh
16+
cd swift-models
17+
swift run -c release Custom-CIFAR10
18+
```

CIFAR/main.swift renamed to Examples/Custom-CIFAR10/main.swift

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Datasets
1516
import TensorFlow
1617

1718
let batchSize = 100
1819

19-
let cifarDataset = loadCIFAR10()
20-
let testBatches = cifarDataset.test.batched(batchSize)
20+
let dataset = CIFAR10()
21+
let testBatches = dataset.testDataset.batched(batchSize)
2122

2223
var model = KerasModel()
2324
let optimizer = RMSProp(for: model, learningRate: 0.0001, decay: 1e-6)
@@ -28,7 +29,7 @@ Context.local.learningPhase = .training
2829
for epoch in 1...100 {
2930
var trainingLossSum: Float = 0
3031
var trainingBatchCount = 0
31-
let trainingShuffled = cifarDataset.training.shuffled(
32+
let trainingShuffled = dataset.trainingDataset.shuffled(
3233
sampleCount: 50000, randomSeed: Int64(epoch))
3334
for batch in trainingShuffled.batched(batchSize) {
3435
let (labels, images) = (batch.label, batch.data)
@@ -52,15 +53,17 @@ for epoch in 1...100 {
5253
testBatchCount += 1
5354

5455
let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
55-
correctGuessCount = correctGuessCount +
56-
Int(Tensor<Int32>(correctPredictions).sum().scalarized())
56+
correctGuessCount = correctGuessCount + Int(
57+
Tensor<Int32>(correctPredictions).sum().scalarized())
5758
totalGuessCount = totalGuessCount + batchSize
5859
}
5960

6061
let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
61-
print("""
62+
print(
63+
"""
6264
[Epoch \(epoch)] \
6365
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
6466
Loss: \(testLossSum / Float(testBatchCount))
65-
""")
67+
"""
68+
)
6669
}

0 commit comments

Comments
 (0)