Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

First steps in repository reorganization: extracting common MNIST dataset code #182

Merged
merged 6 commits into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions Autoencoder/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Simple Autoencoder

This is an example of a simple 1-dimensional autoencoder model, using MNIST as a training dataset. It should produce output similar to the following:

### Epoch 1
<p align="center">
<img src="images/epoch-1-input.png" height="270" width="360">
Expand All @@ -12,23 +14,17 @@
<img src="images/epoch-10-output.png" height="270" width="360">
</p>

This directory builds a simple 1-dimensional autoencoder model.

## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```
swift run Autoencoder
```
If you are using brew to install python2 and modules, change the path:
- remove brew path '/usr/local/bin'
- add TensorFlow swift Toolchain /Library/Developer/Toolchains/swift-latest/usr/bin

swift run -c release Autoencoder
```
export PATH=/Library/Developer/Toolchains/swift-latest/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:"${PATH}"
```
57 changes: 9 additions & 48 deletions Autoencoder/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
import Foundation
import TensorFlow
import Python
import Datasets

// Import Python modules
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")
let plt = Python.import("matplotlib.pyplot")

// Turn off using display on server / linux
// Use the AGG renderer for saving images to disk.
matplotlib.use("Agg")

// Some globals
let plt = Python.import("matplotlib.pyplot")

let epochCount = 10
let batchSize = 100
let outputFolder = "./output/"
Expand All @@ -45,37 +46,6 @@ func plot(image: [Float], name: String) {
plt.close()
}

/// Reads a file into an array of bytes.
///
/// The file is looked up, in order, under ".", "Resources" and "Autoencoder/Resources"; the
/// first folder in which it exists supplies the contents.
///
/// The bytes are read with Foundation's `Data(contentsOf:)` instead of going through Python's
/// `open` and NumPy, so no Python file handle is left open and no force-unwrapped NumPy
/// conversion is needed.
///
/// - Parameter filename: Name of the file, relative to each candidate folder.
/// - Returns: The raw bytes of the file.
///
/// This function does not return on failure: if the file is not found in any folder, or is found
/// but cannot be read, a message is printed and the process exits with status -1.
func readFile(_ filename: String) -> [UInt8] {
    let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
    for folder in possibleFolders {
        let fileURL = URL(fileURLWithPath: folder).appendingPathComponent(filename)
        guard FileManager.default.fileExists(atPath: fileURL.path) else {
            continue
        }
        do {
            return [UInt8](try Data(contentsOf: fileURL))
        } catch {
            // The path exists but is unreadable (e.g. it is a directory or lacks permissions).
            print("Failed to read file at \(fileURL.path): \(error)")
            exit(-1)
        }
    }
    print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
    exit(-1)
}

/// Loads an MNIST image file and its label file and converts them into tensors.
///
/// The first 16 bytes of the images file and the first 8 bytes of the labels file are skipped.
/// Every remaining image byte is converted to `Float` and divided by 255; every remaining label
/// byte is converted to `Int32`.
///
/// - Parameters:
///   - imagesFile: Name of the images file, resolved by `readFile`.
///   - labelsFile: Name of the labels file, resolved by `readFile`.
/// - Returns: The images as a `[rowCount, imageHeight * imageWidth]` tensor and the labels as a
///   rank-1 tensor, where `rowCount` is the number of labels read.
func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
                                                           labels: Tensor<Int32>) {
    print("Reading data.")
    let pixelBytes = readFile(imagesFile).dropFirst(16)
    let labelBytes = readFile(labelsFile).dropFirst(8)
    let pixels = pixelBytes.map { Float($0) }
    let digits = labelBytes.map { Int32($0) }
    let rowCount = digits.count

    print("Constructing data tensors.")
    let imageTensor = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: pixels) / 255.0
    return (images: imageTensor, labels: Tensor(digits))
}

/// An autoencoder.
struct Autoencoder: Layer {
typealias Input = Tensor<Float>
Expand All @@ -91,7 +61,7 @@ struct Autoencoder: Layer {
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
activation: sigmoid)
activation: tanh)

@differentiable
func callAsFunction(_ input: Input) -> Output {
Expand All @@ -100,22 +70,13 @@ struct Autoencoder: Layer {
}
}

// MNIST data logic
/// Returns the `index`-th batch of `x`, using the file-level `batchSize`.
///
/// The batch is the range `index * batchSize ..< (index + 1) * batchSize` taken with the
/// tensor's range subscript.
func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
    let lowerBound = index * batchSize
    let upperBound = lowerBound + batchSize
    return x[lowerBound..<upperBound]
}

let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
labelsFile: "train-labels-idx1-ubyte")
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)

let dataset = MNIST(batchSize: batchSize, flattening: true)
var autoencoder = Autoencoder()
let optimizer = RMSProp(for: autoencoder)

// Training loop
for epoch in 1...epochCount {
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: images[epoch].scalars)
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let testImage = autoencoder(sampleImage)

plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
Expand All @@ -124,8 +85,8 @@ for epoch in 1...epochCount {
let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")

for i in 0 ..< Int(labels.shape[0]) / batchSize {
let x = minibatch(in: images, at: i)
for i in 0 ..< dataset.trainingSize / batchSize {
let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)

let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
let image = autoencoder(x)
Expand Down
104 changes: 104 additions & 0 deletions Datasets/MNIST/MNIST.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import TensorFlow

/// The MNIST training and test sets, read from disk into tensors when the value is created.
public struct MNIST {
    /// Images read from `train-images-idx3-ubyte`.
    public let trainingImages: Tensor<Float>
    /// Labels read from `train-labels-idx1-ubyte`.
    public let trainingLabels: Tensor<Int32>
    /// Images read from `t10k-images-idx3-ubyte`.
    public let testImages: Tensor<Float>
    /// Labels read from `t10k-labels-idx1-ubyte`.
    public let testLabels: Tensor<Int32>

    /// Number of training examples (size of the first dimension of `trainingLabels`).
    public let trainingSize: Int
    /// Number of test examples (size of the first dimension of `testLabels`).
    public let testSize: Int

    /// The batch size this dataset was created with; stored for callers, not used during loading.
    public let batchSize: Int

    /// Loads both the training and the test split.
    ///
    /// - Parameters:
    ///   - batchSize: Stored in `batchSize`.
    ///   - flattening: Forwarded to `readMNIST` for both splits.
    ///   - normalizing: Forwarded to `readMNIST` for both splits.
    public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) {
        self.batchSize = batchSize

        let training = readMNIST(
            imagesFile: "train-images-idx3-ubyte",
            labelsFile: "train-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        trainingImages = training.images
        trainingLabels = training.labels
        trainingSize = Int(training.labels.shape[0])

        let test = readMNIST(
            imagesFile: "t10k-images-idx3-ubyte",
            labelsFile: "t10k-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        testImages = test.images
        testLabels = test.labels
        testSize = Int(test.labels.shape[0])
    }
}

extension Tensor {
    /// Returns batch number `index` of this tensor.
    ///
    /// - Parameters:
    ///   - index: Zero-based batch number.
    ///   - batchSize: Number of entries per batch.
    /// - Returns: The range `index * batchSize ..< (index + 1) * batchSize` of this tensor, taken
    ///   with the tensor's range subscript.
    public func minibatch(at index: Int, batchSize: Int) -> Tensor {
        let lowerBound = index * batchSize
        let upperBound = lowerBound + batchSize
        return self[lowerBound..<upperBound]
    }
}

/// Reads a file into an array of bytes.
///
/// Each entry of `possibleDirectories` is tried in order; the first directory in which `path`
/// exists supplies the contents.
///
/// - Parameters:
///   - path: File path relative to each candidate directory.
///   - possibleDirectories: Directories to search, in priority order.
/// - Returns: The raw bytes of the file.
///
/// This function does not return on failure: if the file is not found in any directory, or is
/// found but cannot be read, a message is printed and the process exits with status -1.
func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] {
    for folder in possibleDirectories {
        let fileURL = URL(fileURLWithPath: folder).appendingPathComponent(path)
        guard FileManager.default.fileExists(atPath: fileURL.path) else {
            continue
        }
        do {
            return [UInt8](try Data(contentsOf: fileURL, options: []))
        } catch {
            // The path exists but is unreadable (e.g. it is a directory or lacks permissions);
            // report it instead of trapping on `try!`.
            print("Failed to read file at \(fileURL.path): \(error)")
            exit(-1)
        }
    }
    print("File not found: \(path) (searched directories: \(possibleDirectories))")
    exit(-1)
}

/// Reads MNIST images and labels from specified file paths.
///
/// Both files are searched for in "." and "./Datasets/MNIST". The first 16 bytes of the images
/// file and the first 8 bytes of the labels file (the IDX file headers) are skipped; the
/// remaining bytes are the pixel values and the labels.
///
/// - Parameters:
///   - imagesFile: Name of the IDX images file.
///   - labelsFile: Name of the IDX labels file.
///   - flattening: If `true`, images have shape `[rowCount, 28 * 28]`; otherwise they have shape
///     `[rowCount, 28, 28, 1]` (NHWC).
///   - normalizing: If `true`, pixel values are mapped to `[-1, 1]`; otherwise they are in
///     `[0, 1]`. This applies to both the flattened and the NHWC layout (previously it was
///     silently ignored when `flattening` was `false`).
/// - Returns: The image tensor and the label tensor, where `rowCount` is the number of labels.
func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> (
    images: Tensor<Float>,
    labels: Tensor<Int32>
) {
    print("Reading data from files: \(imagesFile), \(labelsFile).")
    let searchDirectories = [".", "./Datasets/MNIST"]
    let images = readFile(imagesFile, possibleDirectories: searchDirectories).dropFirst(16)
        .map(Float.init)
    let labels = readFile(labelsFile, possibleDirectories: searchDirectories).dropFirst(8)
        .map(Int32.init)
    let rowCount = labels.count
    let imageHeight = 28
    let imageWidth = 28

    print("Constructing data tensors.")

    // Build the requested layout with pixel values scaled from 0...255 to 0...1.
    var imageTensor: Tensor<Float>
    if flattening {
        imageTensor = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0
    } else {
        imageTensor = Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
            .transposed(withPermutations: [0, 2, 3, 1]) / 255  // NHWC
    }

    // Normalization is independent of the layout, so apply it after either branch.
    if normalizing {
        imageTensor = imageTensor * 2.0 - 1.0
    }
    return (images: imageTensor, labels: Tensor(labels))
}
19 changes: 19 additions & 0 deletions Examples/LeNet-MNIST/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# LeNet-5 with MNIST

This example demonstrates how to train the [LeNet-5 network](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) against the [MNIST digit classification dataset](http://yann.lecun.com/exdb/mnist/).

The LeNet network is instantiated from the ImageClassificationModels library of standard models, and applied to an instance of the MNIST dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.


## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

To train the model, run:

```sh
cd swift-models
swift run -c release LeNet-MNIST
```
82 changes: 82 additions & 0 deletions Examples/LeNet-MNIST/main.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow
import ImageClassificationModels
import Datasets

// Number of passes over the training set made by the training loop below.
let epochCount = 12
// Number of examples per minibatch, used for both the training and the test loop.
let batchSize = 128

// Load MNIST with the initializer's default `flattening` and `normalizing` arguments.
let dataset = MNIST(batchSize: batchSize)
// The model being trained; `var` because the optimizer updates it in place.
var classifier = LeNet()

// Stochastic gradient descent optimizer with a learning rate of 0.1.
let optimizer = SGD(for: classifier, learningRate: 0.1)

print("Beginning training...")

/// Running totals accumulated over the batches of one epoch.
struct Statistics {
    /// Number of predictions that matched their label.
    var correctGuessCount = 0
    /// Number of predictions made.
    var totalGuessCount = 0
    /// Sum of the loss values added so far.
    var totalLoss = Float(0)
}

// The training loop.
// Each epoch makes one pass over the training set (updating the model), then one pass over the
// test set (evaluation only), and finally prints the statistics gathered in both passes.
for epoch in 1...epochCount {
// Fresh counters for every epoch, so the printed numbers are per-epoch.
var trainStats = Statistics()
var testStats = Statistics()
Context.local.learningPhase = .training
// Integer division: only whole batches are visited, so any trailing examples that do not fill
// a complete batch are not used.
for i in 0 ..< dataset.trainingSize / batchSize {
let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
let y = dataset.trainingLabels.minibatch(at: i, batchSize: batchSize)
// Compute the gradient with respect to the model.
// The closure also updates `trainStats` as a side effect while computing the loss.
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
let ŷ = classifier(x)
// A prediction counts as correct when the argmax along axis 1 equals the label.
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
trainStats.correctGuessCount += Int(
Tensor<Int32>(correctPredictions).sum().scalarized())
trainStats.totalGuessCount += batchSize
let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
// `totalLoss` is the sum of the per-batch loss values; it is not averaged over batches.
trainStats.totalLoss += loss.scalarized()
return loss
}
// Update the model's differentiable variables along the gradient vector.
optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
}

Context.local.learningPhase = .inference
// Same whole-batch iteration as above, over the test split; no gradient is computed and the
// model is not updated.
for i in 0 ..< dataset.testSize / batchSize {
let x = dataset.testImages.minibatch(at: i, batchSize: batchSize)
let y = dataset.testLabels.minibatch(at: i, batchSize: batchSize)
// Compute loss on test set
let ŷ = classifier(x)
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
testStats.totalGuessCount += batchSize
let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
testStats.totalLoss += loss.scalarized()
}

// Accuracy is the fraction of counted predictions that were correct.
let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
// The trailing backslashes join the lines of the multi-line literal, so the whole summary is
// printed as a single line per epoch.
print("""
[Epoch \(epoch)] \
Training Loss: \(trainStats.totalLoss), \
Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
(\(trainAccuracy)), \
Test Loss: \(testStats.totalLoss), \
Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
(\(testAccuracy))
""")
}
4 changes: 3 additions & 1 deletion GAN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```sh
swift run GAN
```
```
Binary file removed GAN/Resources/train-images-idx3-ubyte
Binary file not shown.
Loading