Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 52a11a6

Browse files
authored
Adopt derived 'Parameterized' conformance in the Autoencoder model. (#19)
1 parent 5d15f5b commit 52a11a6

File tree

1 file changed

+81
-80
lines changed

1 file changed

+81
-80
lines changed

Autoencoder/Autoencoder.swift

Lines changed: 81 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,25 @@ import TensorFlow
1818
import Python
1919

2020
let outputFolder = "/tmp/mnist-test/"
21-
enum AutoencoderError: Error {
22-
case noDatasetFound
23-
}
2421

25-
func readDataset() throws -> (images: Tensor<Float>, labels: Tensor<Int32>) {
22+
func readDataset() -> (images: Tensor<Float>, labels: Tensor<Int32>)? {
2623
print("Reading the data.")
27-
guard let swiftFile = CommandLine.arguments.first else { throw AutoencoderError.noDatasetFound }
24+
guard let swiftFile = CommandLine.arguments.first else { return nil }
2825
let swiftFileURL = URL(fileURLWithPath: swiftFile)
2926
var imageFolderURL = swiftFileURL.deletingLastPathComponent()
3027
var labelFolderURL = swiftFileURL.deletingLastPathComponent()
3128
imageFolderURL.appendPathComponent("Resources/train-images-idx3-ubyte")
3229
labelFolderURL.appendPathComponent("Resources/train-labels-idx1-ubyte")
33-
34-
let imageData = try Data(contentsOf: imageFolderURL).dropFirst(16)
35-
let labelData = try Data(contentsOf: labelFolderURL).dropFirst(8)
30+
31+
guard let imageData = try? Data(contentsOf: imageFolderURL).dropFirst(16),
32+
let labelData = try? Data(contentsOf: labelFolderURL).dropFirst(8) else {
33+
return nil
34+
}
3635
let images = imageData.map { Float($0) }
3736
let labels = labelData.map { Int32($0) }
3837
let rowCount = Int32(labels.count)
3938
let columnCount = Int32(images.count) / rowCount
40-
39+
4140
print("Constructing the data tensors.")
4241
let imagesTensor = Tensor(shape: [rowCount, columnCount], scalars: images)
4342
let labelsTensor = Tensor(labels)
@@ -93,27 +92,27 @@ func plot(image: [Float], name: String) {
9392
plt.close()
9493
}
9594

96-
struct Autoencoder {
95+
struct Autoencoder : Parameterized {
9796
static let imageEdge: Int32 = 28
9897
static let imageSize: Int32 = imageEdge * imageEdge
9998
static let decoderLayerSize: Int32 = 50
10099
static let encoderLayerSize: Int32 = 50
101100
static let hiddenLayerSize: Int32 = 2
102-
103-
var w1: Tensor<Float>
104-
var w2: Tensor<Float>
105-
var w3: Tensor<Float>
106-
var w4: Tensor<Float>
107-
108-
var b2 = Tensor<Float>(zeros: [1, Autoencoder.hiddenLayerSize])
101+
102+
@TFParameter var w1: Tensor<Float>
103+
@TFParameter var w2: Tensor<Float>
104+
@TFParameter var w3: Tensor<Float>
105+
@TFParameter var w4: Tensor<Float>
106+
@TFParameter var b2 = Tensor<Float>(zeros: [1, Autoencoder.hiddenLayerSize])
107+
109108
var learningRate: Float = 0.001
110-
109+
111110
init() {
112111
let w1 = Tensor<Float>(randomUniform: [Autoencoder.imageSize, Autoencoder.decoderLayerSize])
113112
let w2 = Tensor<Float>(randomUniform: [Autoencoder.decoderLayerSize, Autoencoder.hiddenLayerSize])
114113
let w3 = Tensor<Float>(randomUniform: [Autoencoder.hiddenLayerSize, Autoencoder.encoderLayerSize])
115114
let w4 = Tensor<Float>(randomUniform: [Autoencoder.encoderLayerSize, Autoencoder.imageSize])
116-
115+
117116
// Xavier initialization
118117
self.w1 = w1 / sqrtf(Float(Autoencoder.imageSize))
119118
self.w2 = w2 / sqrtf(Float(Autoencoder.decoderLayerSize))
@@ -126,7 +125,7 @@ extension Autoencoder {
126125
@inline(never)
127126
func embedding(for input: Tensor<Float>) -> (tensor: Tensor<Float>, loss: Float, input: Tensor<Float>, output: Tensor<Float>) {
128127
let inputNormalized = input / 255.0
129-
128+
130129
// Forward pass
131130
let z1 = inputNormalized • w1
132131
let h1 = tanh(z1)
@@ -139,14 +138,14 @@ extension Autoencoder {
139138
let loss: Float = 0.5 * (predictions - inputNormalized).squared().mean()
140139
return (h2, loss, inputNormalized, predictions)
141140
}
142-
141+
143142
mutating func trainStep(input: Tensor<Float>) -> Float {
144143
let learningRate = self.learningRate
145-
144+
146145
// Batch normalization
147146
let inputNormalized = input / 255.0
148147
let batchSize = Tensor<Float>(inputNormalized.shapeTensor[0])
149-
148+
150149
// Forward pass
151150
let z1 = inputNormalized • w1
152151
let h1 = tanh(z1)
@@ -156,31 +155,26 @@ extension Autoencoder {
156155
let h3 = tanh(z3)
157156
let z4 = h3 • w4
158157
let predictions = sigmoid(z4)
159-
158+
160159
// Backward pass
161160
let dz4 = ((predictions - inputNormalized) / batchSize)
162-
let dw4 = h3.transposed(withPermutations: 1, 0) • dz4
163-
164-
let dz3 = matmul(dz4, w4.transposed(withPermutations: 1, 0)) * (1 - h3.squared())
165-
let dw3 = h2.transposed(withPermutations: 1, 0) • dz3
166-
167-
let dz2 = matmul(dz3, w3.transposed(withPermutations: 1, 0))
168-
let dw2 = h1.transposed(withPermutations: 1, 0) • dz2
161+
let dw4 = h3.transposed() • dz4
162+
let dz3 = matmul(dz4, w4.transposed()) * (1 - h3.squared())
163+
let dw3 = h2.transposed() • dz3
164+
let dz2 = matmul(dz3, w3.transposed())
165+
let dw2 = h1.transposed() • dz2
169166
let db2 = dz2.sum(squeezingAxes: 0)
170-
171-
let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * (1 - h1.squared())
172-
let dw1 = inputNormalized.transposed(withPermutations: 1, 0) • dz1
173-
167+
let dz1 = matmul(dz2, w2.transposed()) * (1 - h1.squared())
168+
let dw1 = inputNormalized.transposed() • dz1
169+
let gradients = Parameters(w1: dw1, w2: dw2, w3: dw3, w4: dw4, b2: db2)
170+
174171
let loss: Float = 0.5 * (predictions - inputNormalized).squared().mean()
175-
172+
176173
// Gradient descent.
177-
w1 -= dw1 * learningRate
178-
w2 -= dw2 * learningRate
179-
w3 -= dw3 * learningRate
180-
w4 -= dw4 * learningRate
181-
182-
b2 -= db2 * learningRate
183-
174+
allParameters.update(withGradients: gradients) { p, g in
175+
p -= g * learningRate
176+
}
177+
184178
return loss
185179
}
186180
}
@@ -204,11 +198,11 @@ extension Autoencoder {
204198
}
205199
} while iterationNumber < maxIterations
206200
}
207-
201+
208202
private static func reshape(image: [Float], imageCountPerLine: Int) -> [Float] {
209203
var fullImage: [Float] = []
210204
let imageEdge = Int(Autoencoder.imageEdge)
211-
205+
212206
//FIXME: Improve fors.
213207
for rowIndex in 0..<imageCountPerLine {
214208
for pixelIndex in 0..<imageEdge {
@@ -222,7 +216,7 @@ extension Autoencoder {
222216
}
223217
return fullImage
224218
}
225-
219+
226220
func embedding(from dataset: (images: Tensor<Float>, labels: Tensor<Int32>), shouldSaveInput: Bool, elementCount: Int32, step: Int) -> (labels: Tensor<Int32>, tensor: [Float]) {
227221
let images = dataset.images.slice(lowerBounds: [0, 0], upperBounds: [elementCount, Autoencoder.imageSize])
228222
let labels = dataset.labels.slice(lowerBounds: [0], upperBounds: [elementCount])
@@ -240,38 +234,45 @@ extension Autoencoder {
240234
}
241235
}
242236

243-
244-
func main() {
245-
do {
246-
let dataset = try readDataset()
247-
var autoencoder = Autoencoder()
248-
249-
var embedding = autoencoder.embedding(from: dataset, shouldSaveInput: true, elementCount: 300, step: 0)
250-
plot(image: embedding.tensor, labels: embedding.labels, step: 0)
251-
252-
autoencoder.train(on: dataset)
253-
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 1)
254-
plot(image: embedding.tensor, labels: embedding.labels, step: 1)
255-
256-
autoencoder.train(on: dataset)
257-
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 2)
258-
plot(image: embedding.tensor, labels: embedding.labels, step: 2)
259-
260-
autoencoder.train(on: dataset)
261-
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 3)
262-
plot(image: embedding.tensor, labels: embedding.labels, step: 3)
263-
264-
autoencoder.train(on: dataset)
265-
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 4)
266-
plot(image: embedding.tensor, labels: embedding.labels, step: 4)
267-
268-
autoencoder.train(on: dataset)
269-
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 5)
270-
plot(image: embedding.tensor, labels: embedding.labels, step: 5)
271-
272-
print("Now, you can open /tmp/mnist-test/ folder and review the results.")
273-
} catch {
274-
print(error)
275-
}
237+
guard let dataset = readDataset() else {
238+
print("Error: could not read dataset.")
239+
exit(0)
276240
}
277-
main()
241+
242+
var autoencoder = Autoencoder()
243+
244+
// Initial prediction.
245+
var embedding = autoencoder.embedding(from: dataset, shouldSaveInput: true, elementCount: 300, step: 0)
246+
plot(image: embedding.tensor, labels: embedding.labels, step: 0)
247+
248+
autoencoder.train(on: dataset)
249+
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 1)
250+
plot(image: embedding.tensor, labels: embedding.labels, step: 1)
251+
252+
autoencoder.train(on: dataset)
253+
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 2)
254+
plot(image: embedding.tensor, labels: embedding.labels, step: 2)
255+
256+
autoencoder.train(on: dataset)
257+
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 3)
258+
plot(image: embedding.tensor, labels: embedding.labels, step: 3)
259+
260+
autoencoder.train(on: dataset)
261+
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 4)
262+
plot(image: embedding.tensor, labels: embedding.labels, step: 4)
263+
264+
autoencoder.train(on: dataset)
265+
embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: 5)
266+
plot(image: embedding.tensor, labels: embedding.labels, step: 5)
267+
268+
// Ideally this would be written as a loop. This is currently blocked on some graph program extraction bugs.
269+
//
270+
// for i in 1...5 {
271+
// autoencoder.train(on: dataset)
272+
// embedding = autoencoder.embedding(from: dataset, shouldSaveInput: false, elementCount: 300, step: i)
273+
// plot(image: embedding.tensor, labels: embedding.labels, step: i)
274+
// }
275+
276+
print("Now, you can open /tmp/mnist-test/ folder and review the results.")
277+
278+

0 commit comments

Comments
 (0)