Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 3fa05f7

Browse files
authored
Add JPEG loading / saving via Raw TF operations, removing Matplotlib dependencies (#188)
* Added an Image struct as a wrapper for JPEG loading / saving, removed matplotlib dependency from GAN and Autoencoder examples using this. * Formatting update for Autoencoder. * Bring this inline with current API. * Made saveImage() a throwing function, improved formatting. * Changed function parameter.
1 parent 36fdbb1 commit 3fa05f7

File tree

8 files changed

+215
-107
lines changed

8 files changed

+215
-107
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
.swiftpm
99
cifar-10-batches-py/
1010
cifar-10-batches-bin/
11+
output/

Autoencoder/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ To begin, you'll need the [latest version of Swift for
2121
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
2222
installed. Make sure you've added the correct version of `swift` to your path.
2323

24-
This example requires Matplotlib and NumPy to be installed, for use in image output.
25-
2624
To train the model, run:
2725

2826
```

Autoencoder/main.swift

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,59 +12,38 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Datasets
1516
import Foundation
17+
import ModelSupport
1618
import TensorFlow
17-
import Python
18-
import Datasets
19-
20-
// Import Python modules
21-
let matplotlib = Python.import("matplotlib")
22-
let np = Python.import("numpy")
23-
24-
// Use the AGG renderer for saving images to disk.
25-
matplotlib.use("Agg")
26-
27-
let plt = Python.import("matplotlib.pyplot")
2819

2920
let epochCount = 10
3021
let batchSize = 100
31-
let outputFolder = "./output/"
32-
let imageHeight = 28, imageWidth = 28
22+
let imageHeight = 28
23+
let imageWidth = 28
3324

34-
func plot(image: [Float], name: String) {
35-
// Create figure
36-
let ax = plt.gca()
37-
let array = np.array([image])
38-
let pixels = array.reshape([imageHeight, imageWidth])
39-
if !FileManager.default.fileExists(atPath: outputFolder) {
40-
try! FileManager.default.createDirectory(atPath: outputFolder,
41-
withIntermediateDirectories: false,
42-
attributes: nil)
43-
}
44-
ax.imshow(pixels, cmap: "gray")
45-
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
46-
plt.close()
47-
}
25+
let outputFolder = "./output/"
4826

4927
/// An autoencoder.
5028
struct Autoencoder: Layer {
51-
typealias Input = Tensor<Float>
52-
typealias Output = Tensor<Float>
53-
54-
var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
29+
var encoder1 = Dense<Float>(
30+
inputSize: imageHeight * imageWidth, outputSize: 128,
5531
activation: relu)
32+
5633
var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
5734
var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
5835
var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
5936

6037
var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
6138
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
6239
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
63-
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
40+
41+
var decoder4 = Dense<Float>(
42+
inputSize: 128, outputSize: imageHeight * imageWidth,
6443
activation: tanh)
6544

6645
@differentiable
67-
func callAsFunction(_ input: Input) -> Output {
46+
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
6847
let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
6948
return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
7049
}
@@ -76,11 +55,20 @@ let optimizer = RMSProp(for: autoencoder)
7655

7756
// Training loop
7857
for epoch in 1...epochCount {
79-
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
58+
let sampleImage = Tensor(
59+
shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
8060
let testImage = autoencoder(sampleImage)
8161

82-
plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
83-
plot(image: testImage.scalars, name: "epoch-\(epoch)-output")
62+
do {
63+
try saveImage(
64+
sampleImage, size: (imageWidth, imageHeight), directory: outputFolder,
65+
name: "epoch-\(epoch)-input")
66+
try saveImage(
67+
testImage, size: (imageWidth, imageHeight), directory: outputFolder,
68+
name: "epoch-\(epoch)-output")
69+
} catch {
70+
print("Could not save image with error: \(error)")
71+
}
8472

8573
let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
8674
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")

GAN/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ To begin, you'll need the [latest version of Swift for
1616
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
1717
installed. Make sure you've added the correct version of `swift` to your path.
1818

19-
This example requires Matplotlib and NumPy to be installed, for use in image output.
20-
2119
To train the model, run:
2220

2321
```sh

GAN/main.swift

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,59 +12,42 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Datasets
1516
import Foundation
17+
import ModelSupport
1618
import TensorFlow
17-
import Python
18-
import Datasets
19-
20-
// Import Python modules.
21-
let matplotlib = Python.import("matplotlib")
22-
let np = Python.import("numpy")
23-
24-
// Use the AGG renderer for saving images to disk.
25-
matplotlib.use("Agg")
26-
27-
let plt = Python.import("matplotlib.pyplot")
2819

2920
let epochCount = 10
3021
let batchSize = 32
3122
let outputFolder = "./output/"
32-
let imageHeight = 28, imageWidth = 28
23+
let imageHeight = 28
24+
let imageWidth = 28
3325
let imageSize = imageHeight * imageWidth
3426
let latentSize = 64
3527

36-
func plotImage(_ image: Tensor<Float>, name: String) {
37-
// Create figure.
38-
let ax = plt.gca()
39-
let array = np.array([image.scalars])
40-
let pixels = array.reshape(image.shape)
41-
if !FileManager.default.fileExists(atPath: outputFolder) {
42-
try! FileManager.default.createDirectory(
43-
atPath: outputFolder,
44-
withIntermediateDirectories: false,
45-
attributes: nil)
46-
}
47-
ax.imshow(pixels, cmap: "gray")
48-
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
49-
plt.close()
50-
}
51-
5228
// Models
5329

5430
struct Generator: Layer {
55-
var dense1 = Dense<Float>(inputSize: latentSize, outputSize: latentSize * 2,
56-
activation: { leakyRelu($0) })
57-
var dense2 = Dense<Float>(inputSize: latentSize * 2, outputSize: latentSize * 4,
58-
activation: { leakyRelu($0) })
59-
var dense3 = Dense<Float>(inputSize: latentSize * 4, outputSize: latentSize * 8,
60-
activation: { leakyRelu($0) })
61-
var dense4 = Dense<Float>(inputSize: latentSize * 8, outputSize: imageSize,
62-
activation: tanh)
63-
31+
var dense1 = Dense<Float>(
32+
inputSize: latentSize, outputSize: latentSize * 2,
33+
activation: { leakyRelu($0) })
34+
35+
var dense2 = Dense<Float>(
36+
inputSize: latentSize * 2, outputSize: latentSize * 4,
37+
activation: { leakyRelu($0) })
38+
39+
var dense3 = Dense<Float>(
40+
inputSize: latentSize * 4, outputSize: latentSize * 8,
41+
activation: { leakyRelu($0) })
42+
43+
var dense4 = Dense<Float>(
44+
inputSize: latentSize * 8, outputSize: imageSize,
45+
activation: tanh)
46+
6447
var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
6548
var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
6649
var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)
67-
50+
6851
@differentiable
6952
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
7053
let x1 = batchnorm1(dense1(input))
@@ -75,15 +58,22 @@ struct Generator: Layer {
7558
}
7659

7760
struct Discriminator: Layer {
78-
var dense1 = Dense<Float>(inputSize: imageSize, outputSize: 256,
79-
activation: { leakyRelu($0) })
80-
var dense2 = Dense<Float>(inputSize: 256, outputSize: 64,
81-
activation: { leakyRelu($0) })
82-
var dense3 = Dense<Float>(inputSize: 64, outputSize: 16,
83-
activation: { leakyRelu($0) })
84-
var dense4 = Dense<Float>(inputSize: 16, outputSize: 1,
85-
activation: identity)
86-
61+
var dense1 = Dense<Float>(
62+
inputSize: imageSize, outputSize: 256,
63+
activation: { leakyRelu($0) })
64+
65+
var dense2 = Dense<Float>(
66+
inputSize: 256, outputSize: 64,
67+
activation: { leakyRelu($0) })
68+
69+
var dense3 = Dense<Float>(
70+
inputSize: 64, outputSize: 16,
71+
activation: { leakyRelu($0) })
72+
73+
var dense4 = Dense<Float>(
74+
inputSize: 16, outputSize: 1,
75+
activation: identity)
76+
8777
@differentiable
8878
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
8979
input.sequenced(through: dense1, dense2, dense3, dense4)
@@ -94,16 +84,19 @@ struct Discriminator: Layer {
9484

9585
@differentiable
9686
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
97-
sigmoidCrossEntropy(logits: fakeLogits,
98-
labels: Tensor(ones: fakeLogits.shape))
87+
sigmoidCrossEntropy(
88+
logits: fakeLogits,
89+
labels: Tensor(ones: fakeLogits.shape))
9990
}
10091

10192
@differentiable
10293
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
103-
let realLoss = sigmoidCrossEntropy(logits: realLogits,
104-
labels: Tensor(ones: realLogits.shape))
105-
let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits,
106-
labels: Tensor(zeros: fakeLogits.shape))
94+
let realLoss = sigmoidCrossEntropy(
95+
logits: realLogits,
96+
labels: Tensor(ones: realLogits.shape))
97+
let fakeLoss = sigmoidCrossEntropy(
98+
logits: fakeLogits,
99+
labels: Tensor(zeros: fakeLogits.shape))
107100
return realLoss + fakeLoss
108101
}
109102

@@ -123,18 +116,28 @@ let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
123116
// Noise vectors and plot function for testing
124117
let testImageGridSize = 4
125118
let testVector = sampleVector(size: testImageGridSize * testImageGridSize)
126-
func plotTestImage(_ testImage: Tensor<Float>, name: String) {
127-
var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize,
128-
imageHeight, imageWidth])
119+
120+
func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
121+
var gridImage = testImage.reshaped(
122+
to: [
123+
testImageGridSize, testImageGridSize,
124+
imageHeight, imageWidth,
125+
])
129126
// Add padding.
130127
gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
131128
// Transpose to create single image.
132129
gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
133-
gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize,
134-
(imageWidth + 2) * testImageGridSize])
130+
gridImage = gridImage.reshaped(
131+
to: [
132+
(imageHeight + 2) * testImageGridSize,
133+
(imageWidth + 2) * testImageGridSize,
134+
])
135135
// Convert [-1, 1] range to [0, 1] range.
136136
gridImage = (gridImage + 1) / 2
137-
plotImage(gridImage, name: name)
137+
138+
try saveImage(
139+
gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
140+
name: name)
138141
}
139142

140143
print("Start training...")
@@ -147,20 +150,20 @@ for epoch in 1...epochCount {
147150
// Perform alternative update.
148151
// Update generator.
149152
let vec1 = sampleVector(size: batchSize)
150-
153+
151154
let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
152155
let fakeImages = generator(vec1)
153156
let fakeLogits = discriminator(fakeImages)
154157
let loss = generatorLoss(fakeLogits: fakeLogits)
155158
return loss
156159
}
157160
optG.update(&generator, along: 𝛁generator)
158-
161+
159162
// Update discriminator.
160163
let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
161164
let vec2 = sampleVector(size: batchSize)
162165
let fakeImages = generator(vec2)
163-
166+
164167
let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
165168
let realLogits = discriminator(realImages)
166169
let fakeLogits = discriminator(fakeImages)
@@ -169,12 +172,17 @@ for epoch in 1...epochCount {
169172
}
170173
optD.update(&discriminator, along: 𝛁discriminator)
171174
}
172-
175+
173176
// Start inference phase.
174177
Context.local.learningPhase = .inference
175178
let testImage = generator(testVector)
176-
plotTestImage(testImage, name: "epoch-\(epoch)-output")
177-
179+
180+
do {
181+
try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
182+
} catch {
183+
print("Could not save image grid with error: \(error)")
184+
}
185+
178186
let lossG = generatorLoss(fakeLogits: testImage)
179187
print("[Epoch: \(epoch)] Loss-G: \(lossG)")
180188
}

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ let package = Package(
1111
products: [
1212
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
1313
.library(name: "Datasets", targets: ["Datasets"]),
14+
.library(name: "ModelSupport", targets: ["ModelSupport"]),
1415
.executable(name: "Custom-CIFAR10", targets: ["Custom-CIFAR10"]),
1516
.executable(name: "ResNet-CIFAR10", targets: ["ResNet-CIFAR10"]),
1617
.executable(name: "LeNet-MNIST", targets: ["LeNet-MNIST"]),
@@ -21,7 +22,8 @@ let package = Package(
2122
targets: [
2223
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
2324
.target(name: "Datasets", path: "Datasets"),
24-
.target(name: "Autoencoder", dependencies: ["Datasets"], path: "Autoencoder"),
25+
.target(name: "ModelSupport", path: "Support"),
26+
.target(name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"),
2527
.target(name: "Catch", path: "Catch"),
2628
.target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"),
2729
.target(name: "Gym-CartPole", path: "Gym/CartPole"),
@@ -41,6 +43,6 @@ let package = Package(
4143
sources: ["main.swift"]),
4244
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
4345
.target(name: "Transformer", path: "Transformer"),
44-
.target(name: "GAN", dependencies: ["Datasets"], path: "GAN"),
46+
.target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
4547
]
4648
)

Support/FileManagement.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
public func createDirectoryIfMissing(at path: String) throws {
18+
guard !FileManager.default.fileExists(atPath: path) else { return }
19+
try FileManager.default.createDirectory(
20+
atPath: path,
21+
withIntermediateDirectories: false,
22+
attributes: nil)
23+
}

0 commit comments

Comments
 (0)