@@ -1139,10 +1139,9 @@ final class LayerTests: XCTestCase {
1139
1139
func testRNN( ) {
1140
1140
let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
1141
1141
let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
1142
- let rnn = RNN ( SimpleRNNCell < Float > ( inputSize: 4 , hiddenSize: 4 ,
1143
- seed: ( 0xFeed , 0xBeef ) ) )
1142
+ let rnn = RNN ( SimpleRNNCell < Float > ( inputSize: 4 , hiddenSize: 4 , seed: ( 0xFeed , 0xBeef ) ) )
1144
1143
withTensorLeakChecking {
1145
- let ( outputs, _ ) = valueWithPullback ( at: rnn, inputs) { rnn, inputs in
1144
+ let ( outputs, pullback ) = valueWithPullback ( at: rnn, inputs) { rnn, inputs in
1146
1145
return rnn ( inputs)
1147
1146
}
1148
1147
assertEqual (
@@ -1152,29 +1151,29 @@ final class LayerTests: XCTestCase {
1152
1151
[ 0.23758979 , 0.32101023 , - 0.20359215 , - 0.1787096 ] ,
1153
1152
[ 0.24337786 , 0.3389194 , - 0.21143384 , - 0.1675081 ] ] ,
1154
1153
accuracy: 1e-6 )
1154
+ let ( 𝛁rnn, _) = pullback ( . init( inputs. map { SimpleRNNCell< Float> . State( $0) } ) )
1155
+ // TODO: Verify that RNN gradients are correct using a reference implementation.
1156
+ XCTAssertEqual ( 𝛁rnn. cell. weight,
1157
+ [ [ 0.0 , 0.0 , 0.0 , 0.0 ] ,
1158
+ [ - 0.014372801 , 0.03128201 , 0.07844338 , 0.08569162 ] ,
1159
+ [ - 0.028745603 , 0.06256402 , 0.15688676 , 0.17138325 ] ,
1160
+ [ - 0.043118402 , 0.09384604 , 0.2353301 , 0.25707486 ] ,
1161
+ [ - 0.019920545 , 0.05355064 , 0.13140751 , 0.15169607 ] ,
1162
+ [ - 0.024906494 , 0.06562942 , 0.15947133 , 0.18506715 ] ,
1163
+ [ 0.016476292 , - 0.042923313 , - 0.10459379 , - 0.12082438 ] ,
1164
+ [ 0.013913135 , - 0.040882945 , - 0.100636974 , - 0.11757788 ] ] )
1165
+ XCTAssertEqual ( 𝛁rnn. cell. bias, [ - 0.14372802 , 0.31282014 , 0.78443366 , 0.8569162 ] )
1155
1166
}
1156
- // TODO: Figure out why the following is numerically unstable.
1157
- // let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
1158
- // XCTAssertEqual(𝛁rnn.cell.weight,
1159
- // [[ 0.0, 0.0, 0.0, 0.0],
1160
- // [ 0.02496884, 0.06694733, 0.07978788, -0.022378458],
1161
- // [ 0.04993768, 0.13389467, 0.15957576, -0.044756915],
1162
- // [ 0.07490652, 0.20084201, 0.23936366, -0.06713537],
1163
- // [ 0.0, 0.0, 0.0, 0.0],
1164
- // [ 0.0, 0.0, 0.0, 0.0],
1165
- // [ 0.0, 0.0, 0.0, 0.0],
1166
- // [ 0.0, 0.0, 0.0, 0.0]])
1167
- // XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
1168
1167
}
1169
1168
1170
1169
func testLSTM( ) {
1171
1170
withRandomSeedForTensorFlow ( ( 0xFeed , 0xBeef ) ) {
1172
1171
let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
1173
1172
let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
1174
- let rnn = RNN ( LSTMCell < Float > ( inputSize: 4 , hiddenSize: 4 ) )
1173
+ let lstm = RNN ( LSTMCell < Float > ( inputSize: 4 , hiddenSize: 4 ) )
1175
1174
withTensorLeakChecking {
1176
- let ( outputs, _ ) = valueWithPullback ( at: rnn , inputs) { rnn , inputs in
1177
- return rnn ( inputs)
1175
+ let ( outputs, pullback ) = valueWithPullback ( at: lstm , inputs) { lstm , inputs in
1176
+ return lstm ( inputs)
1178
1177
}
1179
1178
assertEqual (
1180
1179
outputs. map { $0. cell. squeezingShape ( at: 0 ) } [ 0 ] ,
@@ -1190,22 +1189,62 @@ final class LayerTests: XCTestCase {
1190
1189
[ 0.074910110 , 0.021107012 , - 0.049724963 , - 0.069670826 ] ,
1191
1190
[ 0.078670055 , 0.022462710 , - 0.051899005 , - 0.075331904 ] ] ,
1192
1191
accuracy: 1e-6 )
1192
+ let ( 𝛁lstm, _) = pullback ( . init( inputs. map { LSTMCell< Float> . State( cell: $0, hidden: $0) } ) )
1193
+ // TODO: Verify that LSTM gradients are correct using a reference implementation.
1194
+ XCTAssertEqual ( 𝛁lstm. cell. fusedWeight,
1195
+ [ [ 0.0 , 0.0 , 0.0 , 0.0 ,
1196
+ 0.0 , 0.0 , 0.0 , 0.0 ,
1197
+ 0.0 , 0.0 , 0.0 , 0.0 ,
1198
+ 0.0 , 0.0 , 0.0 , 0.0 ] ,
1199
+ [ 0.00012854872 , 0.0013978262 , - 0.0064465487 , - 0.011084668 ,
1200
+ 0.001252454 , 0.04924231 , 0.1023805 , 0.12028344 ,
1201
+ 3.0466243e-05 , 0.0006108698 , - 0.0027553777 , - 0.0048254076 ,
1202
+ 0.00011663328 , 0.0006076429 , - 0.0026212593 , - 0.003298801 ] ,
1203
+ [ 0.00025709745 , 0.0027956525 , - 0.0128930975 , - 0.022169337 ,
1204
+ 0.002504908 , 0.09848462 , 0.204761 , 0.24056688 ,
1205
+ 6.0932485e-05 , 0.0012217396 , - 0.0055107553 , - 0.009650815 ,
1206
+ 0.00023326656 , 0.0012152859 , - 0.0052425186 , - 0.006597602 ] ,
1207
+ [ 0.00038564618 , 0.0041934787 , - 0.019339647 , - 0.03325401 ,
1208
+ 0.003757362 , 0.14772694 , 0.3071415 , 0.36085027 ,
1209
+ 9.1398724e-05 , 0.0018326094 , - 0.008266133 , - 0.014476223 ,
1210
+ 0.00034989987 , 0.0018229289 , - 0.007863778 , - 0.009896403 ] ,
1211
+ [ 2.7438582e-05 , 0.00056287006 , - 0.0024641054 , - 0.004909771 ,
1212
+ 0.00028730888 , 0.019899525 , 0.0410647 , 0.050809838 ,
1213
+ 1.6388643e-05 , 0.0003807871 , - 0.0017060185 , - 0.0030680457 ,
1214
+ 3.7163307e-05 , 0.00029245956 , - 0.0012287574 , - 0.0018296391 ] ,
1215
+ [ 7.462907e-06 , 0.00015513944 , - 0.00067863404 , - 0.0013554879 ,
1216
+ 7.8164114e-05 , 0.0054846643 , 0.011315366 , 0.014021275 ,
1217
+ 4.4683666e-06 , 0.000105314364 , - 0.0004715426 , - 0.0008500715 ,
1218
+ 1.0132099e-05 , 8.078475e-05 , - 0.00033919656 , - 0.00050792634 ] ,
1219
+ [ - 1.818974e-05 , - 0.0003736046 , 0.0016354292 , 0.0032592756 ,
1220
+ - 0.00019047626 , - 0.013208302 , - 0.027256217 , - 0.033727698 ,
1221
+ - 1.0870848e-05 , - 0.00025284386 , 0.0011327572 , 0.002037401 ,
1222
+ - 2.465073e-05 , - 0.00019415902 , 0.000815714 , 0.0012148761 ] ,
1223
+ [ - 2.3125162e-05 , - 0.0004929221 , 0.0021531105 , 0.00431989 ,
1224
+ - 0.00024233271 , - 0.017425863 , - 0.035934873 , - 0.0446482 ,
1225
+ - 1.3914708e-05 , - 0.00033675073 , 0.0015061073 , 0.0027270648 ,
1226
+ - 3.15488e-05 , - 0.0002577127 , 0.001080812 , 0.0016348549 ] ] )
1227
+ XCTAssertEqual ( 𝛁lstm. cell. fusedBias,
1228
+ [ 0.0012854873 , 0.013978262 , - 0.06446548 , - 0.11084669 ,
1229
+ 0.01252454 , 0.49242306 , 1.023805 , 1.2028344 ,
1230
+ 0.0003046624 , 0.0061086984 , - 0.027553776 , - 0.048254073 ,
1231
+ 0.0011663327 , 0.006076429 , - 0.02621259 , - 0.032988008 ] )
1193
1232
}
1194
1233
}
1195
1234
}
1196
1235
1197
1236
func testGRU( ) {
1198
1237
let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
1199
1238
let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
1200
- let rnn = RNN ( GRUCell < Float > (
1239
+ let gru = RNN ( GRUCell < Float > (
1201
1240
inputSize: 4 ,
1202
1241
hiddenSize: 4 ,
1203
1242
weightInitializer: glorotUniform ( seed: ( 0xFeed , 0xBeef ) ) ,
1204
1243
biasInitializer: zeros ( ) )
1205
1244
)
1206
1245
withTensorLeakChecking {
1207
- let ( outputs, _ ) = valueWithPullback ( at: rnn , inputs) { rnn , inputs in
1208
- return rnn ( inputs)
1246
+ let ( outputs, pullback ) = valueWithPullback ( at: gru , inputs) { gru , inputs in
1247
+ return gru ( inputs)
1209
1248
}
1210
1249
assertEqual (
1211
1250
outputs. map { $0. hidden } [ 0 ] ,
@@ -1214,6 +1253,23 @@ final class LayerTests: XCTestCase {
1214
1253
[ 0.2230835 , 0.2230835 , 0.2230835 , 0.2230835 ] ,
1215
1254
[ 0.2383619 , 0.2383619 , 0.2383619 , 0.2383619 ] ] ,
1216
1255
accuracy: 1e-5 )
1256
+ // TODO: Verify that GRU gradients are correct using a reference implementation.
1257
+ let ( 𝛁gru, _) = pullback ( . init( inputs. map { GRUCell< Float> . State( hidden: $0) } ) )
1258
+ XCTAssertEqual ( 𝛁gru. cell. updateWeight1,
1259
+ [ [ 0.0 ] , [ - 0.040293925 ] , [ - 0.08058785 ] , [ - 0.12088178 ] ] )
1260
+ XCTAssertEqual ( 𝛁gru. cell. updateWeight2,
1261
+ [ [ - 0.056792725 ] , [ - 0.056792725 ] , [ - 0.056792725 ] , [ - 0.056792725 ] ] )
1262
+ XCTAssertEqual ( 𝛁gru. cell. resetWeight1,
1263
+ [ [ 0.0 ] , [ 0.0039126356 ] , [ 0.007825271 ] , [ 0.011737906 ] ] )
1264
+ XCTAssertEqual ( 𝛁gru. cell. resetWeight2,
1265
+ [ [ 0.0069182813 ] , [ 0.0069182813 ] , [ 0.0069182813 ] , [ 0.0069182813 ] ] )
1266
+ XCTAssertEqual ( 𝛁gru. cell. outputWeight1,
1267
+ [ [ 0.0 ] , [ 0.1221647 ] , [ 0.2443294 ] , [ 0.3664941 ] ] )
1268
+ XCTAssertEqual ( 𝛁gru. cell. outputWeight2,
1269
+ [ [ 0.08078343 ] , [ 0.08078343 ] , [ 0.08078343 ] , [ 0.08078343 ] ] )
1270
+ XCTAssertEqual ( 𝛁gru. cell. updateBias, [ - 0.016739635 , - 0.04493352 , - 0.13216142 , - 0.20910467 ] )
1271
+ XCTAssertEqual ( 𝛁gru. cell. resetBias, [ 0.023218961 , - 0.024303729 , 0.010057628 , 0.030153492 ] )
1272
+ XCTAssertEqual ( 𝛁gru. cell. outputBias, [ 0.06667276 , 0.115095116 , 0.39864573 , 0.6412333 ] )
1217
1273
}
1218
1274
}
1219
1275
0 commit comments