Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit c79f62e

Browse files
authored
First steps in repository reorganization: extracting common MNIST dataset code (#182)
* Extracted MNIST dataset, created LeNet network, added example combining the two. * Extracted redundant MNIST loading code from GAN and Autoencoder examples, replaced with central MNIST dataset. * Renamed input parameters and applied standard formatting style to MNIST. * Punctuation correction. Co-Authored-By: Richard Wei <[email protected]> * README formatting update. Co-Authored-By: Richard Wei <[email protected]> * Renamed trainImages -> trainingImages, corrected Python package names, formatted Package.swift.
1 parent 7fb9185 commit c79f62e

File tree

18 files changed

+285
-267
lines changed

18 files changed

+285
-267
lines changed

Autoencoder/README.md

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Simple Autoencoder
22

3+
This is an example of a simple 1-dimensional autoencoder model, using MNIST as a training dataset. It should produce output similar to the following:
4+
35
### Epoch 1
46
<p align="center">
57
<img src="images/epoch-1-input.png" height="270" width="360">
@@ -12,23 +14,17 @@
1214
<img src="images/epoch-10-output.png" height="270" width="360">
1315
</p>
1416

15-
This directory builds a simple 1-dimensional autoencoder model.
1617

1718
## Setup
1819

1920
To begin, you'll need the [latest version of Swift for
2021
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
2122
installed. Make sure you've added the correct version of `swift` to your path.
2223

24+
This example requires Matplotlib and NumPy to be installed, for use in image output.
25+
2326
To train the model, run:
2427

2528
```
26-
swift run Autoencoder
27-
```
28-
If you are using brew to install python2 and its modules, change the path:
29-
- remove brew path '/usr/local/bin'
30-
- add TensorFlow swift Toolchain /Library/Developer/Toolchains/swift-latest/usr/bin
31-
29+
swift run -c release Autoencoder
3230
```
33-
export PATH=/Library/Developer/Toolchains/swift-latest/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:"${PATH}"
34-
```

Autoencoder/main.swift

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515
import Foundation
1616
import TensorFlow
1717
import Python
18+
import Datasets
1819

1920
// Import Python modules
2021
let matplotlib = Python.import("matplotlib")
2122
let np = Python.import("numpy")
22-
let plt = Python.import("matplotlib.pyplot")
2323

24-
// Turn off using display on server / linux
24+
// Use the AGG renderer for saving images to disk.
2525
matplotlib.use("Agg")
2626

27-
// Some globals
27+
let plt = Python.import("matplotlib.pyplot")
28+
2829
let epochCount = 10
2930
let batchSize = 100
3031
let outputFolder = "./output/"
@@ -45,37 +46,6 @@ func plot(image: [Float], name: String) {
4546
plt.close()
4647
}
4748

48-
/// Reads a file into an array of bytes.
49-
func readFile(_ filename: String) -> [UInt8] {
50-
let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
51-
for folder in possibleFolders {
52-
let parent = URL(fileURLWithPath: folder)
53-
let filePath = parent.appendingPathComponent(filename).path
54-
guard FileManager.default.fileExists(atPath: filePath) else {
55-
continue
56-
}
57-
let d = Python.open(filePath, "rb").read()
58-
return Array(numpy: np.frombuffer(d, dtype: np.uint8))!
59-
}
60-
print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
61-
exit(-1)
62-
}
63-
64-
/// Reads MNIST images and labels from specified file paths.
65-
func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
66-
labels: Tensor<Int32>) {
67-
print("Reading data.")
68-
let images = readFile(imagesFile).dropFirst(16).map { Float($0) }
69-
let labels = readFile(labelsFile).dropFirst(8).map { Int32($0) }
70-
let rowCount = labels.count
71-
72-
print("Constructing data tensors.")
73-
return (
74-
images: Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0,
75-
labels: Tensor(labels)
76-
)
77-
}
78-
7949
/// An autoencoder.
8050
struct Autoencoder: Layer {
8151
typealias Input = Tensor<Float>
@@ -91,7 +61,7 @@ struct Autoencoder: Layer {
9161
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
9262
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
9363
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
94-
activation: sigmoid)
64+
activation: tanh)
9565

9666
@differentiable
9767
func callAsFunction(_ input: Input) -> Output {
@@ -100,22 +70,13 @@ struct Autoencoder: Layer {
10070
}
10171
}
10272

103-
// MNIST data logic
104-
func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
105-
let start = index * batchSize
106-
return x[start..<start+batchSize]
107-
}
108-
109-
let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
110-
labelsFile: "train-labels-idx1-ubyte")
111-
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
112-
73+
let dataset = MNIST(batchSize: batchSize, flattening: true)
11374
var autoencoder = Autoencoder()
11475
let optimizer = RMSProp(for: autoencoder)
11576

11677
// Training loop
11778
for epoch in 1...epochCount {
118-
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: images[epoch].scalars)
79+
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
11980
let testImage = autoencoder(sampleImage)
12081

12182
plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
@@ -124,8 +85,8 @@ for epoch in 1...epochCount {
12485
let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
12586
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
12687

127-
for i in 0 ..< Int(labels.shape[0]) / batchSize {
128-
let x = minibatch(in: images, at: i)
88+
for i in 0 ..< dataset.trainingSize / batchSize {
89+
let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
12990

13091
let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
13192
let image = autoencoder(x)

Datasets/MNIST/MNIST.swift

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import TensorFlow
17+
18+
/// The MNIST digit dataset, read eagerly from local files when the value is created.
///
/// Both the training and the test split are loaded through `readMNIST`, so every property
/// is populated by the time `init` returns.
public struct MNIST {
    /// Images of the training split, read from "train-images-idx3-ubyte".
    public let trainingImages: Tensor<Float>
    /// Labels of the training split, read from "train-labels-idx1-ubyte".
    public let trainingLabels: Tensor<Int32>
    /// Images of the test split, read from "t10k-images-idx3-ubyte".
    public let testImages: Tensor<Float>
    /// Labels of the test split, read from "t10k-labels-idx1-ubyte".
    public let testLabels: Tensor<Int32>

    /// Size of the first dimension of `trainingLabels`.
    public let trainingSize: Int
    /// Size of the first dimension of `testLabels`.
    public let testSize: Int

    /// The batch size passed to `init`. It is only stored here; loading does not use it.
    public let batchSize: Int

    /// Loads the training and test splits.
    ///
    /// - Parameters:
    ///   - batchSize: Value recorded in the `batchSize` property.
    ///   - flattening: Forwarded to `readMNIST`; selects between flat `[count, 28 * 28]`
    ///     images and `[count, 28, 28, 1]` images.
    ///   - normalizing: Forwarded unchanged to `readMNIST`.
    public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) {
        // The training split is read before the test split.
        let training = readMNIST(
            imagesFile: "train-images-idx3-ubyte",
            labelsFile: "train-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        let test = readMNIST(
            imagesFile: "t10k-images-idx3-ubyte",
            labelsFile: "t10k-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)

        self.batchSize = batchSize

        trainingImages = training.images
        trainingLabels = training.labels
        trainingSize = Int(training.labels.shape[0])

        testImages = test.images
        testLabels = test.labels
        testSize = Int(test.labels.shape[0])
    }
}
51+
52+
extension Tensor {
    /// Returns the slice `index * batchSize ..< (index + 1) * batchSize` of this tensor,
    /// taken with the tensor's range subscript.
    ///
    /// - Parameters:
    ///   - index: Zero-based position of the batch.
    ///   - batchSize: Number of consecutive elements per batch.
    /// - Returns: The selected slice. No bounds checking is done here beyond what the
    ///   range subscript itself performs.
    public func minibatch(at index: Int, batchSize: Int) -> Tensor {
        let lowerBound = index * batchSize
        let upperBound = lowerBound + batchSize
        return self[lowerBound..<upperBound]
    }
}
58+
59+
/// Reads a file into an array of bytes.
///
/// The candidate directories are tried in order; the first one that contains `path` is used.
///
/// - Parameters:
///   - path: File name (or relative path) to resolve inside each candidate directory.
///   - possibleDirectories: Directories to search, in priority order.
/// - Returns: The full contents of the first matching file.
///
/// If no candidate directory contains the file, or the matching file cannot be read, a
/// diagnostic is printed and the process exits with status -1; the function does not
/// return in either case.
func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] {
    for directory in possibleDirectories {
        let fileURL = URL(fileURLWithPath: directory).appendingPathComponent(path)
        guard FileManager.default.fileExists(atPath: fileURL.path) else {
            continue
        }
        do {
            return [UInt8](try Data(contentsOf: fileURL))
        } catch {
            // A file that exists but cannot be read (permissions, a directory with the same
            // name, I/O failure) is reported explicitly instead of trapping inside `try!`.
            print("Failed to read file at \(fileURL.path): \(error)")
            exit(-1)
        }
    }
    // Name the directories that were searched so a misplaced dataset is easy to diagnose.
    print("File not found: \(path) (searched directories: \(possibleDirectories))")
    exit(-1)
}
73+
74+
/// Reads MNIST images and labels from specified file paths.
///
/// Both files are looked up in "." and then in "./Datasets/MNIST". Pixel bytes are converted
/// to `Float` and divided by 255. When `normalizing` is `true`, the scaled values are then
/// mapped with `x * 2 - 1`, for either image layout.
///
/// - Parameters:
///   - imagesFile: Name of the image file.
///   - labelsFile: Name of the label file.
///   - flattening: When `true`, images have shape `[count, 28 * 28]`; otherwise they have
///     shape `[count, 28, 28, 1]` (NHWC).
///   - normalizing: When `true`, applies the `x * 2 - 1` mapping after the division by 255.
/// - Returns: The image tensor and a label tensor with one `Int32` per image.
func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> (
    images: Tensor<Float>,
    labels: Tensor<Int32>
) {
    print("Reading data from files: \(imagesFile), \(labelsFile).")
    let searchDirectories = [".", "./Datasets/MNIST"]
    // The leading bytes of each file are header data, not samples: 16 bytes for the image
    // file and 8 bytes for the label file (see the MNIST IDX format description).
    let pixels = readFile(imagesFile, possibleDirectories: searchDirectories).dropFirst(16)
        .map(Float.init)
    let labels = readFile(labelsFile, possibleDirectories: searchDirectories).dropFirst(8)
        .map(Int32.init)
    let rowCount = labels.count
    let imageHeight = 28
    let imageWidth = 28

    // Fail with a readable message if the two files do not describe the same number of
    // examples, rather than tripping a shape check during tensor construction.
    let expectedPixelCount = rowCount * imageHeight * imageWidth
    precondition(
        pixels.count == expectedPixelCount,
        "Expected \(expectedPixelCount) pixel values for \(rowCount) labels in \(imagesFile), "
            + "found \(pixels.count).")

    print("Constructing data tensors.")

    var images: Tensor<Float>
    if flattening {
        images = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: pixels) / 255.0
    } else {
        images = Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: pixels)
            .transposed(withPermutations: [0, 2, 3, 1]) / 255  // NHWC
    }
    // Previously this step ran only for flattened images, so `normalizing: true` had no
    // effect with `flattening: false`. It now applies to both layouts.
    if normalizing {
        images = images * 2.0 - 1.0
    }
    return (images: images, labels: Tensor(labels))
}
File renamed without changes.
File renamed without changes.

Examples/LeNet-MNIST/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# LeNet-5 with MNIST
2+
3+
This example demonstrates how to train the [LeNet-5 network](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) against the [MNIST digit classification dataset](http://yann.lecun.com/exdb/mnist/).
4+
5+
The LeNet network is instantiated from the ImageClassificationModels library of standard models, and applied to an instance of the MNIST dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.
6+
7+
8+
## Setup
9+
10+
To begin, you'll need the [latest version of Swift for
11+
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
12+
installed. Make sure you've added the correct version of `swift` to your path.
13+
14+
To train the model, run:
15+
16+
```sh
17+
cd swift-models
18+
swift run -c release LeNet-MNIST
19+
```

Examples/LeNet-MNIST/main.swift

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
import ImageClassificationModels
17+
import Datasets
18+
19+
// Hyperparameters for the run.
let batchSize = 128
let epochCount = 12

// The dataset is loaded first, then the model to train is created.
let dataset = MNIST(batchSize: batchSize)
var classifier = LeNet()

// Stochastic gradient descent over the classifier's parameters.
let optimizer = SGD(for: classifier, learningRate: 0.1)

print("Beginning training...")
28+
29+
/// Running totals accumulated over the batches of one dataset split.
struct Statistics {
    /// Number of predictions that matched their label.
    var correctGuessCount = 0
    /// Number of predictions made.
    var totalGuessCount = 0
    /// Sum of the per-batch loss values.
    var totalLoss: Float = 0
}
34+
35+
// The training loop: each epoch makes one pass over the training split (updating the
// model), then one pass over the test split (measurement only), and prints a summary line.
for epoch in 1...epochCount {
    var trainStats = Statistics()
    var testStats = Statistics()

    // Training pass. Integer division means a trailing partial batch, if any, is not visited.
    Context.local.learningPhase = .training
    let trainingBatchCount = dataset.trainingSize / batchSize
    for batchIndex in 0..<trainingBatchCount {
        let images = dataset.trainingImages.minibatch(at: batchIndex, batchSize: batchSize)
        let labels = dataset.trainingLabels.minibatch(at: batchIndex, batchSize: batchSize)
        // Compute the gradient with respect to the model. The statistics are recorded inside
        // the closure because the logits are only available there.
        let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
            let logits = classifier(images)
            let matches = logits.argmax(squeezingAxis: 1) .== labels
            let matchCount = Tensor<Int32>(matches).sum().scalarized()
            trainStats.correctGuessCount += Int(matchCount)
            trainStats.totalGuessCount += batchSize
            let loss = softmaxCrossEntropy(logits: logits, labels: labels)
            trainStats.totalLoss += loss.scalarized()
            return loss
        }
        // Update the model's differentiable variables along the gradient vector.
        optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
    }

    // Evaluation pass over the test split; the model is not updated here.
    Context.local.learningPhase = .inference
    let testBatchCount = dataset.testSize / batchSize
    for batchIndex in 0..<testBatchCount {
        let images = dataset.testImages.minibatch(at: batchIndex, batchSize: batchSize)
        let labels = dataset.testLabels.minibatch(at: batchIndex, batchSize: batchSize)
        let logits = classifier(images)
        let matches = logits.argmax(squeezingAxis: 1) .== labels
        let matchCount = Tensor<Int32>(matches).sum().scalarized()
        testStats.correctGuessCount += Int(matchCount)
        testStats.totalGuessCount += batchSize
        let loss = softmaxCrossEntropy(logits: logits, labels: labels)
        testStats.totalLoss += loss.scalarized()
    }

    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
    print("""
        [Epoch \(epoch)] \
        Training Loss: \(trainStats.totalLoss), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
        (\(trainAccuracy)), \
        Test Loss: \(testStats.totalLoss), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
        (\(testAccuracy))
        """)
}

GAN/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ To begin, you'll need the [latest version of Swift for
1616
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
1717
installed. Make sure you've added the correct version of `swift` to your path.
1818

19+
This example requires Matplotlib and NumPy to be installed, for use in image output.
20+
1921
To train the model, run:
2022

2123
```sh
2224
swift run GAN
23-
```
25+
```

GAN/Resources/train-images-idx3-ubyte

-44.9 MB
Binary file not shown.

0 commit comments

Comments
 (0)