This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 5d15f5b

Update MNIST model to adopt ParameterAggregate synthesis. (#18)
1 parent: 2d1f616


MNIST/MNIST.swift

Lines changed: 39 additions & 68 deletions
@@ -15,15 +15,11 @@
 import Foundation
 import TensorFlow
 
-/// Returns the images tensor and labels tensor.
-public func readMnist(
-  imagesFile: String, labelsFile: String
-) -> (Tensor<Float>, Tensor<Int32>) {
+/// Reads MNIST images and labels from specified file paths.
+func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>, labels: Tensor<Int32>) {
   print("Reading data.")
-  let imageData =
-    try! Data(contentsOf: URL(fileURLWithPath: imagesFile)).dropFirst(16)
-  let labelData =
-    try! Data(contentsOf: URL(fileURLWithPath: labelsFile)).dropFirst(8)
+  let imageData = try! Data(contentsOf: URL(fileURLWithPath: imagesFile)).dropFirst(16)
+  let labelData = try! Data(contentsOf: URL(fileURLWithPath: labelsFile)).dropFirst(8)
   let images = imageData.map { Float($0) }
   let labels = labelData.map { Int32($0) }
   let rowCount = Int32(labels.count)
@@ -35,96 +31,71 @@ public func readMnist(
   return (imagesTensor.toAccelerator(), labelsTensor.toAccelerator())
 }
 
-func main() {
+/// Parameters of an MNIST classifier.
+struct MNISTParameters: ParameterAggregate {
+  var w1 = Tensor<Float>(randomUniform: [784, 30])
+  var w2 = Tensor<Float>(randomUniform: [30, 10])
+  var b1 = Tensor<Float>(zeros: [1, 30])
+  var b2 = Tensor<Float>(zeros: [1, 10])
+}
+
+/// Trains an MNIST classifier for the specified number of iterations.
+func train(_ parameters: inout MNISTParameters, iterationCount: Int) {
   // Get script directory. This is necessary for MNIST.swift to work when
   // invoked from any directory.
-  let currentDirectory =
-    URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
+  let currentDirectory = URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
   let currentScriptPath = URL(fileURLWithPath: CommandLine.arguments[0],
                               relativeTo: currentDirectory)
   let scriptDirectory = currentScriptPath.appendingPathComponent("..")
 
   // Get training data.
-  let imagesFile =
-    scriptDirectory.appendingPathComponent("train-images-idx3-ubyte").path
-  let labelsFile =
-    scriptDirectory.appendingPathComponent("train-labels-idx1-ubyte").path
-  let (images, numericLabels) = readMnist(imagesFile: imagesFile,
-                                          labelsFile: labelsFile)
+  let imagesFile = scriptDirectory.appendingPathComponent("train-images-idx3-ubyte").path
+  let labelsFile = scriptDirectory.appendingPathComponent("train-labels-idx1-ubyte").path
+  let (images, numericLabels) = readMNIST(imagesFile: imagesFile, labelsFile: labelsFile)
   let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
-  // FIXME: Defining batchSize as a scalar, or as a tensor as follows instead
-  // of returning it from readMnist() crashes the compiler:
-  // https://bugs.swift.org/browse/SR-7706
-  // let batchSize = Tensor<Float>(Float(images.shape[0]))
-  let batchSize = Tensor<Float>(images.shapeTensor[0])
+  let batchSize = Float(images.shape[0])
 
   // Hyper-parameters.
-  let iterationCount: Int32 = 20
   let learningRate: Float = 0.2
   var loss = Float.infinity
 
-  // Parameters.
-  var w1 = Tensor<Float>(randomUniform: [784, 30])
-  var w2 = Tensor<Float>(randomUniform: [30, 10])
-  var b1 = Tensor<Float>(zeros: [1, 30])
-  var b2 = Tensor<Float>(zeros: [1, 10])
-
   // Training loop.
   print("Begin training for \(iterationCount) iterations.")
 
-  var i: Int32 = 0
-  repeat {
+  for _ in 0..<iterationCount {
     // Forward pass.
-    let z1 = images • w1 + b1
+    let z1 = images • parameters.w1 + parameters.b1
     let h1 = sigmoid(z1)
-    let z2 = h1 • w2 + b2
+    let z2 = h1 • parameters.w2 + parameters.b2
     let predictions = sigmoid(z2)
 
-    // Backward pass.
+    // Backward pass. This will soon be replaced by automatic
+    // differentiation.
     let dz2 = (predictions - labels) / batchSize
-    let dw2 = h1.transposed(withPermutations: 1, 0) • dz2
+    let dw2 = h1.transposed() • dz2
     let db2 = dz2.sum(squeezingAxes: 0)
-    let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
-    let dw1 = images.transposed(withPermutations: 1, 0) • dz1
+    let dz1 = matmul(dz2, parameters.w2.transposed()) * h1 * (1 - h1)
+    let dw1 = images.transposed() • dz1
     let db1 = dz1.sum(squeezingAxes: 0)
+    let gradients = MNISTParameters(w1: dw1, w2: dw2, b1: db1, b2: db2)
 
-    // Gradient descent.
-    w1 -= dw1 * learningRate
-    b1 -= db1 * learningRate
-    w2 -= dw2 * learningRate
-    b2 -= db2 * learningRate
+    // Update parameters.
+    parameters.update(withGradients: gradients) { param, grad in
+      param -= grad * learningRate
+    }
 
     // Update the sigmoid-based cross-entropy loss, where we treat the 10
     // class labels as independent. This is unnecessary for the MNIST case,
     // where we want to predict a single label. In that case we should
    // consider switching to a softmax-based cross-entropy loss.
-    //
-    // Let m be the batch size, y be the target labels, and A be the
-    // predictions. The formula expressed in TF expression is:
-    // 1/m * tf.reduce_sum(- y * tf.log(A) - (1-y) * tf.log(1-A))
     let part1 = -labels * log(predictions)
     let part2 = -(1 - labels) * log(1 - predictions)
-    // FIXME: Remove scalarized() call when we make `batchSize` scalar,
-    // after fixing https://bugs.swift.org/browse/SR-7706
-    loss = (part1 + part2).sum() / batchSize.scalarized()
-    // To print out the loss value per iteration, uncomment the following
-    // code.
-    // FIXME: Fix runtime hanging when we print loss directly instead of
-    // printing via lossTensor: https://bugs.swift.org/browse/SR-7705
-    // let lossTensor = Tensor<Float>(loss)
-    // print(lossTensor)
-
-    // Update iteration count.
-    i += 1
-  } while i < iterationCount
-
-  // Print loss.
-  print("Loss: \(loss)")
-  // Uncomment the code below if we also print out loss per loop iteration
-  // above. This will not be necessary after fixing
-  // https://bugs.swift.org/browse/SR-7705.
-  // let lossTensor = Tensor<Float>(loss)
-  // print(lossTensor)
+    loss = (part1 + part2).sum() / batchSize
+
+    print("Loss:", loss)
+  }
 }
 
-main()
+var parameters = MNISTParameters()
+// Start training.
+train(&parameters, iterationCount: 20)
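
Note on the headline change: the parameters.update(withGradients:) call in the new training loop relies on the conformance the compiler synthesizes for ParameterAggregate. A rough hand-written equivalent for MNISTParameters, shown purely for illustration (a sketch, not the actual synthesized code), would behave like:

extension MNISTParameters {
  // Applies `updater` to each stored parameter and its corresponding
  // gradient, in declaration order. This mirrors the behavior the
  // synthesized ParameterAggregate conformance provides.
  mutating func update(
    withGradients gradients: MNISTParameters,
    _ updater: (inout Tensor<Float>, Tensor<Float>) -> Void
  ) {
    updater(&w1, gradients.w1)
    updater(&w2, gradients.w2)
    updater(&b1, gradients.b1)
    updater(&b2, gradients.b2)
  }
}

This is why adding another layer only requires adding a stored property to MNISTParameters: the gradient-descent update picks it up without further changes to the loop.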

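The comment retained in the loop suggests switching to a softmax-based cross-entropy loss, since MNIST predicts a single label per image. A minimal sketch of that alternative, assuming a keep-dimensions sum(alongAxes:) reduction exists alongside the sum(squeezingAxes:) variant used above, with z2, labels, and batchSize as defined in the training loop:

// Hypothetical alternative, not part of this commit: softmax-based
// cross-entropy over the 10 mutually exclusive digit classes.
let exps = exp(z2)
let probabilities = exps / exps.sum(alongAxes: 1)  // normalize each row
let softmaxLoss = (-labels * log(probabilities)).sum() / batchSize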