15
15
import Foundation
16
16
import TensorFlow
17
17
18
- /// Returns the images tensor and labels tensor.
19
- public func readMnist(
20
- imagesFile: String , labelsFile: String
21
- ) -> ( Tensor < Float > , Tensor < Int32 > ) {
18
+ /// Reads MNIST images and labels from specified file paths.
19
+ func readMNIST( imagesFile: String , labelsFile: String ) -> ( images: Tensor < Float > , labels: Tensor < Int32 > ) {
22
20
print ( " Reading data. " )
23
- let imageData =
24
- try ! Data ( contentsOf: URL ( fileURLWithPath: imagesFile) ) . dropFirst ( 16 )
25
- let labelData =
26
- try ! Data ( contentsOf: URL ( fileURLWithPath: labelsFile) ) . dropFirst ( 8 )
21
+ let imageData = try ! Data ( contentsOf: URL ( fileURLWithPath: imagesFile) ) . dropFirst ( 16 )
22
+ let labelData = try ! Data ( contentsOf: URL ( fileURLWithPath: labelsFile) ) . dropFirst ( 8 )
27
23
let images = imageData. map { Float ( $0) }
28
24
let labels = labelData. map { Int32 ( $0) }
29
25
let rowCount = Int32 ( labels. count)
@@ -35,96 +31,71 @@ public func readMnist(
35
31
return ( imagesTensor. toAccelerator ( ) , labelsTensor. toAccelerator ( ) )
36
32
}
37
33
38
- func main( ) {
34
/// Parameters of an MNIST classifier: a two-layer fully-connected network
/// (784 inputs → 30 hidden units → 10 outputs) trained with sigmoid
/// activations in `train(_:iterationCount:)`.
///
/// Property order must stay `w1, w2, b1, b2` — the synthesized memberwise
/// initializer is used to build the gradient aggregate during training.
struct MNISTParameters: ParameterAggregate {
  /// First-layer weights, drawn uniformly at random.
  var w1 = Tensor<Float>(randomUniform: [784, 30])
  /// Second-layer weights, drawn uniformly at random.
  var w2 = Tensor<Float>(randomUniform: [30, 10])
  /// First-layer bias, zero-initialized.
  var b1 = Tensor<Float>(zeros: [1, 30])
  /// Second-layer bias, zero-initialized.
  var b2 = Tensor<Float>(zeros: [1, 10])
}
41
+
42
/// Trains an MNIST classifier for the specified number of iterations.
///
/// Loads `train-images-idx3-ubyte` / `train-labels-idx1-ubyte` from the
/// directory containing the executable, then runs plain full-batch gradient
/// descent with a hand-derived backward pass, printing the loss after each
/// step.
///
/// - Parameters:
///   - parameters: Model parameters, updated in place.
///   - iterationCount: Number of gradient-descent steps to perform.
func train(_ parameters: inout MNISTParameters, iterationCount: Int) {
  // Get script directory. This is necessary for MNIST.swift to work when
  // invoked from any directory.
  let currentDirectory = URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
  let currentScriptPath = URL(fileURLWithPath: CommandLine.arguments[0],
                              relativeTo: currentDirectory)
  let scriptDirectory = currentScriptPath.appendingPathComponent("..")

  // Get training data.
  let imagesFile = scriptDirectory.appendingPathComponent("train-images-idx3-ubyte").path
  let labelsFile = scriptDirectory.appendingPathComponent("train-labels-idx1-ubyte").path
  let (images, numericLabels) = readMNIST(imagesFile: imagesFile, labelsFile: labelsFile)
  let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
  let batchSize = Float(images.shape[0])

  // Hyper-parameters.
  let learningRate: Float = 0.2
  var loss = Float.infinity

  // Training loop.
  print("Begin training for \(iterationCount) iterations.")

  // FIX: use a half-open range. The previous `0 ... iterationCount` closed
  // range ran `iterationCount + 1` steps, contradicting this function's
  // documented contract.
  for _ in 0 ..< iterationCount {
    // Forward pass.
    let z1 = images • parameters.w1 + parameters.b1
    let h1 = sigmoid(z1)
    let z2 = h1 • parameters.w2 + parameters.b2
    let predictions = sigmoid(z2)

    // Backward pass. This will soon be replaced by automatic
    // differentiation.
    let dz2 = (predictions - labels) / batchSize
    let dw2 = h1.transposed() • dz2
    let db2 = dz2.sum(squeezingAxes: 0)
    let dz1 = matmul(dz2, parameters.w2.transposed()) * h1 * (1 - h1)
    let dw1 = images.transposed() • dz1
    let db1 = dz1.sum(squeezingAxes: 0)
    // Package the gradients using the same aggregate type as the parameters.
    let gradients = MNISTParameters(w1: dw1, w2: dw2, b1: db1, b2: db2)

    // Update parameters via per-parameter gradient descent.
    parameters.update(withGradients: gradients) { param, grad in
      param -= grad * learningRate
    }

    // Update the sigmoid-based cross-entropy loss, where we treat the 10
    // class labels as independent. This is unnecessary for the MNIST case,
    // where we want to predict a single label. In that case we should
    // consider switching to a softmax-based cross-entropy loss.
    let part1 = -labels * log(predictions)
    let part2 = -(1 - labels) * log(1 - predictions)
    loss = (part1 + part2).sum() / batchSize

    print("Loss:", loss)
  }
}
129
98
130
- main ( )
99
// Entry point: build freshly initialized parameters, then run 20 training
// iterations, mutating them in place.
var model = MNISTParameters()
train(&model, iterationCount: 20)
0 commit comments