@@ -67,7 +67,7 @@ final class LayerTests: XCTestCase {
67
67
}
68
68
69
69
func testAvgPool3D( ) {
70
- let layer = AvgPool3D < Float > ( poolSize: ( 2 , 4 , 5 ) , stride : ( 1 , 1 , 1 ) , padding: . valid)
70
+ let layer = AvgPool3D < Float > ( poolSize: ( 2 , 4 , 5 ) , strides : ( 1 , 1 , 1 ) , padding: . valid)
71
71
let input = Tensor ( shape: [ 1 , 2 , 4 , 5 , 1 ] , scalars: ( 0 ..< 20 ) . map ( Float . init) )
72
72
let output = layer. inferring ( from: input)
73
73
let expected = Tensor < Float > ( [ [ [ [ [ 9.5 ] ] ] ] ] )
@@ -147,13 +147,18 @@ final class LayerTests: XCTestCase {
147
147
var cell = SimpleRNNCell < Float > ( inputSize: 2 , hiddenSize: 5 )
148
148
cell. weight = weight
149
149
cell. bias = bias
150
- let state = Tensor < Float > ( ones: [ 1 , 5 ] ) * Tensor < Float > ( [ 1 , 0.2 , 0.5 , 2 , 0.6 ] )
150
+ let state = SimpleRNNCell . State (
151
+ Tensor < Float > ( ones: [ 1 , 5 ] ) * Tensor < Float > ( [ 1 , 0.2 , 0.5 , 2 , 0.6 ] )
152
+ )
151
153
let input = Tensor < Float > ( ones: [ 1 , 2 ] ) * Tensor < Float > ( [ 0.3 , 0.7 ] )
152
154
let output = cell ( input: input, state: state) . state
153
- let expected = Tensor < Float > ( [ [ 2.76649 , 6.2999997 , 2.76649 , 6.2999997 , 2.76649 ] ] )
155
+ let expected = SimpleRNNCell . State (
156
+ Tensor < Float > ( [ [ 2.76649 , 6.2999997 , 2.76649 , 6.2999997 , 2.76649 ] ] )
157
+ )
154
158
XCTAssertEqual ( output, expected)
155
159
}
156
160
161
+ // TODO(TF-507): Remove references to `SimpleRNNCell.State` after SR-10697 is fixed.
157
162
func testRNN( ) {
158
163
let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
159
164
let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
@@ -162,11 +167,12 @@ final class LayerTests: XCTestCase {
162
167
let ( outputs, pullback) = rnn. valueWithPullback ( at: inputs) { rnn, inputs in
163
168
return rnn ( inputs)
164
169
}
165
- XCTAssertEqual ( outputs, [ [ [ - 0.0026294366 , - 0.0058668107 , 0.04495003 , 0.20311214 ] ] ,
166
- [ [ 0.06788494 , 0.050665878 , 0.02415526 , 0.09249911 ] ] ,
167
- [ [ 0.06621192 , 0.009049267 , 0.065047316 , 0.11534518 ] ] ,
168
- [ [ 0.05612204 , 0.00022032857 , 0.05407162 , 0.09784105 ] ] ] )
169
- let ( 𝛁rnn, 𝛁inputs) = pullback( . init( inputs) )
170
+ XCTAssertEqual ( outputs. map { $0. value } ,
171
+ [ [ [ - 0.0026294366 , - 0.0058668107 , 0.04495003 , 0.20311214 ] ] ,
172
+ [ [ 0.06788494 , 0.050665878 , 0.02415526 , 0.09249911 ] ] ,
173
+ [ [ 0.06621192 , 0.009049267 , 0.065047316 , 0.11534518 ] ] ,
174
+ [ [ 0.05612204 , 0.00022032857 , 0.05407162 , 0.09784105 ] ] ] )
175
+ let ( 𝛁rnn, 𝛁inputs) = pullback( . init( inputs. map { SimpleRNNCell< Float> . State( $0) } ) )
170
176
XCTAssertEqual ( 𝛁rnn. cell. weight,
171
177
[ [ 0.0 , 0.0 , 0.0 , 0.0 ] ,
172
178
[ - 0.0051278225 , 0.0013102926 , 0.00740262 , 0.018119661 ] ,
0 commit comments