Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Add JPEG loading / saving via Raw TF operations, removing Matplotlib dependencies #188

Merged
merged 6 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
.swiftpm
cifar-10-batches-py/
cifar-10-batches-bin/
output/
2 changes: 0 additions & 2 deletions Autoencoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```
Expand Down
60 changes: 24 additions & 36 deletions Autoencoder/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import Foundation
import ModelSupport
import TensorFlow
import Python
import Datasets

// Import Python modules
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")

// Use the AGG renderer for saving images to disk.
matplotlib.use("Agg")

let plt = Python.import("matplotlib.pyplot")

let epochCount = 10
let batchSize = 100
let outputFolder = "./output/"
let imageHeight = 28, imageWidth = 28
let imageHeight = 28
let imageWidth = 28

func plot(image: [Float], name: String) {
// Create figure
let ax = plt.gca()
let array = np.array([image])
let pixels = array.reshape([imageHeight, imageWidth])
if !FileManager.default.fileExists(atPath: outputFolder) {
try! FileManager.default.createDirectory(atPath: outputFolder,
withIntermediateDirectories: false,
attributes: nil)
}
ax.imshow(pixels, cmap: "gray")
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
plt.close()
}
let outputFolder = "./output/"

/// An autoencoder.
struct Autoencoder: Layer {
typealias Input = Tensor<Float>
typealias Output = Tensor<Float>

var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
var encoder1 = Dense<Float>(
inputSize: imageHeight * imageWidth, outputSize: 128,
activation: relu)

var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)

var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,

var decoder4 = Dense<Float>(
inputSize: 128, outputSize: imageHeight * imageWidth,
activation: tanh)

@differentiable
func callAsFunction(_ input: Input) -> Output {
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
}
Expand All @@ -76,11 +55,20 @@ let optimizer = RMSProp(for: autoencoder)

// Training loop
for epoch in 1...epochCount {
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let sampleImage = Tensor(
shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let testImage = autoencoder(sampleImage)

plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
plot(image: testImage.scalars, name: "epoch-\(epoch)-output")
do {
try saveImage(
sampleImage, size: (imageWidth, imageHeight), directory: outputFolder,
name: "epoch-\(epoch)-input")
try saveImage(
testImage, size: (imageWidth, imageHeight), directory: outputFolder,
name: "epoch-\(epoch)-output")
} catch {
print("Could not save image with error: \(error)")
}

let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
Expand Down
2 changes: 0 additions & 2 deletions GAN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```sh
Expand Down
138 changes: 73 additions & 65 deletions GAN/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import Foundation
import ModelSupport
import TensorFlow
import Python
import Datasets

// Import Python modules.
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")

// Use the AGG renderer for saving images to disk.
matplotlib.use("Agg")

let plt = Python.import("matplotlib.pyplot")

let epochCount = 10
let batchSize = 32
let outputFolder = "./output/"
let imageHeight = 28, imageWidth = 28
let imageHeight = 28
let imageWidth = 28
let imageSize = imageHeight * imageWidth
let latentSize = 64

func plotImage(_ image: Tensor<Float>, name: String) {
// Create figure.
let ax = plt.gca()
let array = np.array([image.scalars])
let pixels = array.reshape(image.shape)
if !FileManager.default.fileExists(atPath: outputFolder) {
try! FileManager.default.createDirectory(
atPath: outputFolder,
withIntermediateDirectories: false,
attributes: nil)
}
ax.imshow(pixels, cmap: "gray")
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
plt.close()
}

// Models

struct Generator: Layer {
var dense1 = Dense<Float>(inputSize: latentSize, outputSize: latentSize * 2,
activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: latentSize * 2, outputSize: latentSize * 4,
activation: { leakyRelu($0) })
var dense3 = Dense<Float>(inputSize: latentSize * 4, outputSize: latentSize * 8,
activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: latentSize * 8, outputSize: imageSize,
activation: tanh)

var dense1 = Dense<Float>(
inputSize: latentSize, outputSize: latentSize * 2,
activation: { leakyRelu($0) })

var dense2 = Dense<Float>(
inputSize: latentSize * 2, outputSize: latentSize * 4,
activation: { leakyRelu($0) })

var dense3 = Dense<Float>(
inputSize: latentSize * 4, outputSize: latentSize * 8,
activation: { leakyRelu($0) })

var dense4 = Dense<Float>(
inputSize: latentSize * 8, outputSize: imageSize,
activation: tanh)

var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let x1 = batchnorm1(dense1(input))
Expand All @@ -75,15 +58,22 @@ struct Generator: Layer {
}

struct Discriminator: Layer {
var dense1 = Dense<Float>(inputSize: imageSize, outputSize: 256,
activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: 256, outputSize: 64,
activation: { leakyRelu($0) })
var dense3 = Dense<Float>(inputSize: 64, outputSize: 16,
activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: 16, outputSize: 1,
activation: identity)

var dense1 = Dense<Float>(
inputSize: imageSize, outputSize: 256,
activation: { leakyRelu($0) })

var dense2 = Dense<Float>(
inputSize: 256, outputSize: 64,
activation: { leakyRelu($0) })

var dense3 = Dense<Float>(
inputSize: 64, outputSize: 16,
activation: { leakyRelu($0) })

var dense4 = Dense<Float>(
inputSize: 16, outputSize: 1,
activation: identity)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
input.sequenced(through: dense1, dense2, dense3, dense4)
Expand All @@ -94,16 +84,19 @@ struct Discriminator: Layer {

@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(ones: fakeLogits.shape))
sigmoidCrossEntropy(
logits: fakeLogits,
labels: Tensor(ones: fakeLogits.shape))
}

@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
let realLoss = sigmoidCrossEntropy(logits: realLogits,
labels: Tensor(ones: realLogits.shape))
let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(zeros: fakeLogits.shape))
let realLoss = sigmoidCrossEntropy(
logits: realLogits,
labels: Tensor(ones: realLogits.shape))
let fakeLoss = sigmoidCrossEntropy(
logits: fakeLogits,
labels: Tensor(zeros: fakeLogits.shape))
return realLoss + fakeLoss
}

Expand All @@ -123,18 +116,28 @@ let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
// Noise vectors and plot function for testing
let testImageGridSize = 4
let testVector = sampleVector(size: testImageGridSize * testImageGridSize)
func plotTestImage(_ testImage: Tensor<Float>, name: String) {
var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize,
imageHeight, imageWidth])

func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
var gridImage = testImage.reshaped(
to: [
testImageGridSize, testImageGridSize,
imageHeight, imageWidth,
])
// Add padding.
gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
// Transpose to create single image.
gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize,
(imageWidth + 2) * testImageGridSize])
gridImage = gridImage.reshaped(
to: [
(imageHeight + 2) * testImageGridSize,
(imageWidth + 2) * testImageGridSize,
])
// Convert [-1, 1] range to [0, 1] range.
gridImage = (gridImage + 1) / 2
plotImage(gridImage, name: name)

try saveImage(
gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
name: name)
}

print("Start training...")
Expand All @@ -147,20 +150,20 @@ for epoch in 1...epochCount {
// Perform alternative update.
// Update generator.
let vec1 = sampleVector(size: batchSize)

let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
let fakeImages = generator(vec1)
let fakeLogits = discriminator(fakeImages)
let loss = generatorLoss(fakeLogits: fakeLogits)
return loss
}
optG.update(&generator, along: 𝛁generator)

// Update discriminator.
let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
let vec2 = sampleVector(size: batchSize)
let fakeImages = generator(vec2)

let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
let realLogits = discriminator(realImages)
let fakeLogits = discriminator(fakeImages)
Expand All @@ -169,12 +172,17 @@ for epoch in 1...epochCount {
}
optD.update(&discriminator, along: 𝛁discriminator)
}

// Start inference phase.
Context.local.learningPhase = .inference
let testImage = generator(testVector)
plotTestImage(testImage, name: "epoch-\(epoch)-output")


do {
try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
} catch {
print("Could not save image grid with error: \(error)")
}

let lossG = generatorLoss(fakeLogits: testImage)
print("[Epoch: \(epoch)] Loss-G: \(lossG)")
}
6 changes: 4 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let package = Package(
products: [
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
.library(name: "Datasets", targets: ["Datasets"]),
.library(name: "ModelSupport", targets: ["ModelSupport"]),
.executable(name: "Custom-CIFAR10", targets: ["Custom-CIFAR10"]),
.executable(name: "ResNet-CIFAR10", targets: ["ResNet-CIFAR10"]),
.executable(name: "LeNet-MNIST", targets: ["LeNet-MNIST"]),
Expand All @@ -21,7 +22,8 @@ let package = Package(
targets: [
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
.target(name: "Datasets", path: "Datasets"),
.target(name: "Autoencoder", dependencies: ["Datasets"], path: "Autoencoder"),
.target(name: "ModelSupport", path: "Support"),
.target(name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"),
.target(name: "Catch", path: "Catch"),
.target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"),
.target(name: "Gym-CartPole", path: "Gym/CartPole"),
Expand All @@ -41,6 +43,6 @@ let package = Package(
sources: ["main.swift"]),
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
.target(name: "Transformer", path: "Transformer"),
.target(name: "GAN", dependencies: ["Datasets"], path: "GAN"),
.target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
]
)
23 changes: 23 additions & 0 deletions Support/FileManagement.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

public func createDirectoryIfMissing(at path: String) throws {
guard !FileManager.default.fileExists(atPath: path) else { return }
try FileManager.default.createDirectory(
atPath: path,
withIntermediateDirectories: false,
attributes: nil)
}
Loading