Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit be86ede

Browse files
VolodymyrPavliukevychdan-zheng
authored andcommitted
Matrix multiplication operator was changed, code needs update. (#15)
Fixed `MultiplicationPrecedence` operator issue. Apply new API.
1 parent a1c73e6 commit be86ede

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

Autoencoder/Autoencoder.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ extension Autoencoder {
128128
let inputNormalized = input / 255.0
129129

130130
// Forward pass
131-
let z1 = inputNormalized w1
131+
let z1 = inputNormalized w1
132132
let h1 = tanh(z1)
133-
let z2 = h1 w2 + b2
133+
let z2 = h1 w2 + b2
134134
let h2 = z2
135-
let z3 = h2 w3
135+
let z3 = h2 w3
136136
let h3 = tanh(z3)
137-
let z4 = h3 w4
137+
let z4 = h3 w4
138138
let predictions = sigmoid(z4)
139139
let loss: Float = 0.5 * (predictions - inputNormalized).squared().mean()
140140
return (h2, loss, inputNormalized, predictions)
@@ -148,28 +148,28 @@ extension Autoencoder {
148148
let batchSize = Tensor<Float>(inputNormalized.shapeTensor[0])
149149

150150
// Forward pass
151-
let z1 = inputNormalized w1
151+
let z1 = inputNormalized w1
152152
let h1 = tanh(z1)
153-
let z2 = h1 w2 + b2
153+
let z2 = h1 w2 + b2
154154
let h2 = z2
155-
let z3 = h2 w3
155+
let z3 = h2 w3
156156
let h3 = tanh(z3)
157-
let z4 = h3 w4
157+
let z4 = h3 w4
158158
let predictions = sigmoid(z4)
159159

160160
// Backward pass
161161
let dz4 = ((predictions - inputNormalized) / batchSize)
162-
let dw4 = h3.transposed(withPermutations: 1, 0) dz4
162+
let dw4 = h3.transposed(withPermutations: 1, 0) dz4
163163

164164
let dz3 = matmul(dz4, w4.transposed(withPermutations: 1, 0)) * (1 - h3.squared())
165-
let dw3 = h2.transposed(withPermutations: 1, 0) dz3
165+
let dw3 = h2.transposed(withPermutations: 1, 0) dz3
166166

167167
let dz2 = matmul(dz3, w3.transposed(withPermutations: 1, 0))
168-
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
168+
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
169169
let db2 = dz2.sum(squeezingAxes: 0)
170170

171171
let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * (1 - h1.squared())
172-
let dw1 = inputNormalized.transposed(withPermutations: 1, 0) dz1
172+
let dw1 = inputNormalized.transposed(withPermutations: 1, 0) dz1
173173

174174
let loss: Float = 0.5 * (predictions - inputNormalized).squared().mean()
175175

0 commit comments

Comments
 (0)