Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit dfcc60d

Browse files
committed
Adding test and review changes
1 parent 31c28ff commit dfcc60d

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -686,15 +686,17 @@ public struct SeparableConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
686686
/// - Returns: The output.
687687
@differentiable
688688
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
689-
depthwise = depthwiseConv2D( input,
690-
filter: depthwiseFilter,
691-
strides: (1, strides.0, strides.1, 1),
692-
padding: padding)
693-
return activation(Conv2D(
689+
let depthwise = depthwiseConv2D(
690+
input,
691+
filter: depthwiseFilter,
692+
strides: (1, strides.0, strides.1, 1),
693+
padding: padding)
694+
return activation(conv2D(
694695
depthwise,
695696
filter: pointwiseFilter,
696697
strides: (1, 1, 1, 1),
697-
padding: padding) + bias)
698+
padding: padding,
699+
dilations: (1, 1, 1, 1)) + bias)
698700
}
699701
}
700702

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,24 @@ final class LayerTests: XCTestCase {
155155
XCTAssertEqual(output, expected)
156156
}
157157

158+
func testSeparableConv2D() {
159+
let depthwiseFilter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
160+
let pointwiseFilter = Tensor(shape: [1, 1, 4, 1], scalars: (0..<4).map(Float.init))
161+
let bias = Tensor<Float>([4])
162+
let layer = SeparableConv2D<Float>(depthwiseFilter: depthwiseFilter,
163+
pointwiseFilter: pointwiseFilter,
164+
bias: bias,
165+
activation: identity,
166+
strides: (2, 2),
167+
padding: .valid)
168+
let input = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
169+
let output = layer.inferring(from: input)
170+
let expected = Tensor<Float>(shape: [2, 1, 1, 1],
171+
scalars: [1016, 2616])
172+
XCTAssertEqual(output, expected)
173+
}
174+
175+
158176
func testZeroPadding1D() {
159177
let input = Tensor<Float>([0.0, 1.0, 2.0])
160178
let layer = ZeroPadding1D<Float>(padding: 2)
@@ -397,6 +415,7 @@ final class LayerTests: XCTestCase {
397415
("testConv2DDilation", testConv2DDilation),
398416
("testConv3D", testConv3D),
399417
("testDepthConv2D", testDepthConv2D),
418+
("testSeparableConv2D", testSeparableConv2D),
400419
("testZeroPadding1D", testZeroPadding1D),
401420
("testZeroPadding2D", testZeroPadding2D),
402421
("testZeroPadding3D", testZeroPadding3D),

0 commit comments

Comments
 (0)