Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit a32fd58

Browse files
brettkooncerxwei
authored andcommitted
update mnist for new layers api (#109)
* update mnist for new layers api * remove public keyword * indentation
1 parent 9e3ac12 commit a32fd58

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

MNIST/main.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>
5353

5454
/// A classifier.
5555
struct Classifier: Layer {
56+
typealias Input = Tensor<Float>
57+
typealias Output = Tensor<Float>
58+
5659
var conv1a = Conv2D<Float>(filterShape: (3, 3, 1, 32), activation: relu)
5760
var conv1b = Conv2D<Float>(filterShape: (3, 3, 32, 64), activation: relu)
5861
var pool1 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
@@ -64,7 +67,7 @@ struct Classifier: Layer {
6467
var layer1b = Dense<Float>(inputSize: 128, outputSize: 10, activation: softmax)
6568

6669
@differentiable
67-
func applied(to input: Tensor<Float>) -> Tensor<Float> {
70+
func call(_ input: Input) -> Output {
6871
let convolved = input.sequenced(through: conv1a, conv1b, pool1)
6972
return convolved.sequenced(through: dropout1a, flatten, layer1a, dropout1b, layer1b)
7073
}
@@ -83,7 +86,7 @@ let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
8386
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
8487

8588
var classifier = Classifier()
86-
let optimizer = RMSProp<Classifier, Float>()
89+
let optimizer = RMSProp(for: classifier)
8790

8891
// The training loop.
8992
for epoch in 1...epochCount {
@@ -95,7 +98,7 @@ for epoch in 1...epochCount {
9598
let y = minibatch(in: numericLabels, at: i)
9699
// Compute the gradient with respect to the model.
97100
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
98-
let ŷ = classifier.applied(to: x)
101+
let ŷ = classifier(x)
99102
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
100103
correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
101104
totalGuessCount += batchSize

0 commit comments

Comments
 (0)