Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d4d90fb

Browse files
authored
Fix compilation errors. (#125)
- Fix VJPs for `maxPooled3D` and `averagePooled3D`. - Temporarily, fix tests to use `SimpleRNNCell.State`.
1 parent 76f5ee6 commit d4d90fb

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1420,7 +1420,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
14201420

14211421
// TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
14221422
// SR-10697 is fixed.
1423-
public struct State: Differentiable {
1423+
public struct State: Equatable, Differentiable, VectorNumeric, KeyPathIterable {
14241424
public let value: Tensor<Scalar>
14251425
public init(_ value: Tensor<Scalar>) {
14261426
self.value = value

Sources/DeepLearning/Operators.swift

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,10 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
262262
origInput: self,
263263
origOutput: value,
264264
grad: v,
265-
ksize: Tensor<Int32>([Int32(kernelSize.0), Int32(kernelSize.1),
266-
Int32(kernelSize.2), Int32(kernelSize.3),
267-
Int32(kernelSize.4)]),
268-
strides: Tensor<Int32>([Int32(strides.0), Int32(strides.1),
269-
Int32(strides.2), Int32(strides.3),
270-
Int32(strides.4)]),
265+
ksize: [Int32(kernelSize.0), Int32(kernelSize.1), Int32(kernelSize.2),
266+
Int32(kernelSize.3), Int32(kernelSize.4)],
267+
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3),
268+
Int32(strides.4)],
271269
padding: padding.raw
272270
)
273271
})
@@ -308,12 +306,10 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
308306
return Raw.avgPool3DGrad(
309307
origInputShape: self.shapeTensor,
310308
grad: v,
311-
ksize: Tensor<Int32>([Int32(kernelSize.0), Int32(kernelSize.1),
312-
Int32(kernelSize.2), Int32(kernelSize.3),
313-
Int32(kernelSize.4)]),
314-
strides: Tensor<Int32>([Int32(strides.0), Int32(strides.1),
315-
Int32(strides.2), Int32(strides.3),
316-
Int32(strides.4)]),
309+
ksize: [Int32(kernelSize.0), Int32(kernelSize.1), Int32(kernelSize.2),
310+
Int32(kernelSize.3), Int32(kernelSize.4)],
311+
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3),
312+
Int32(strides.4)],
317313
padding: padding.raw
318314
)
319315
})

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ final class LayerTests: XCTestCase {
6767
}
6868

6969
func testAvgPool3D() {
70-
let layer = AvgPool3D<Float>(poolSize: (2, 4, 5), stride: (1, 1, 1), padding: .valid)
70+
let layer = AvgPool3D<Float>(poolSize: (2, 4, 5), strides: (1, 1, 1), padding: .valid)
7171
let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<20).map(Float.init))
7272
let output = layer.inferring(from: input)
7373
let expected = Tensor<Float>([[[[[9.5]]]]])
@@ -147,13 +147,18 @@ final class LayerTests: XCTestCase {
147147
var cell = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
148148
cell.weight = weight
149149
cell.bias = bias
150-
let state = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
150+
let state = SimpleRNNCell.State(
151+
Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
152+
)
151153
let input = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
152154
let output = cell(input: input, state: state).state
153-
let expected = Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]])
155+
let expected = SimpleRNNCell.State(
156+
Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]])
157+
)
154158
XCTAssertEqual(output, expected)
155159
}
156160

161+
// TODO(TF-507): Remove references to `SimpleRNNCell.State` after SR-10697 is fixed.
157162
func testRNN() {
158163
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
159164
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
@@ -162,11 +167,12 @@ final class LayerTests: XCTestCase {
162167
let (outputs, pullback) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
163168
return rnn(inputs)
164169
}
165-
XCTAssertEqual(outputs, [[[-0.0026294366, -0.0058668107, 0.04495003, 0.20311214]],
166-
[[ 0.06788494, 0.050665878, 0.02415526, 0.09249911]],
167-
[[ 0.06621192, 0.009049267, 0.065047316, 0.11534518]],
168-
[[ 0.05612204, 0.00022032857, 0.05407162, 0.09784105]]])
169-
let (𝛁rnn, 𝛁inputs) = pullback(.init(inputs))
170+
XCTAssertEqual(outputs.map { $0.value },
171+
[[[-0.0026294366, -0.0058668107, 0.04495003, 0.20311214]],
172+
[[ 0.06788494, 0.050665878, 0.02415526, 0.09249911]],
173+
[[ 0.06621192, 0.009049267, 0.065047316, 0.11534518]],
174+
[[ 0.05612204, 0.00022032857, 0.05407162, 0.09784105]]])
175+
let (𝛁rnn, 𝛁inputs) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
170176
XCTAssertEqual(𝛁rnn.cell.weight,
171177
[[ 0.0, 0.0, 0.0, 0.0],
172178
[-0.0051278225, 0.0013102926, 0.00740262, 0.018119661],

0 commit comments

Comments
 (0)