Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2bdb02f

Browse files
jon-tow authored and rxwei committed
[Layers] Update DepthwiseConv2D(filterShape:...) bias initialization (#441)
**Issue**: `DepthwiseConv2D(filterShape:strides:padding:activation:filterInitializer:...)` produces a shape-mismatch error in the conv-bias sum of `callAsFunction()` when initialized with a `filterShape` whose channel multiplier is greater than `1`. More specifically, this is caused by initializing the `bias` vector with a dimension equal to the channel multiplier. **Fix**: This PR adds the proper `bias` vector initialization by computing the dimensionality as: `bias vector dimension = output channel count = (input channel count * channel multiplier)`.
1 parent f0a6da6 commit 2bdb02f

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
7070
///
7171
/// and padding size is determined by the padding scheme.
7272
///
73-
/// - Parameter input: The input to the layer [batch count, input width, input channel count].
74-
/// - Returns: The output of shape [batch count, output width, output channel count].
73+
/// - Parameter input: The input to the layer [batch size, input width, input channel count].
74+
/// - Returns: The output of shape [batch size, output width, output channel count].
7575
///
7676
/// - Note: Padding size equals zero when using `.valid`.
7777
@differentiable
@@ -186,7 +186,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
186186
/// and padding sizes are determined by the padding scheme.
187187
///
188188
/// - Parameter input: The input to the layer of shape
189-
/// [batch count, input height, input width, input channel count].
189+
/// [batch size, input height, input width, input channel count].
190190
/// - Returns: The output of shape
191191
/// [batch count, output height, output width, output channel count].
192192
///
@@ -495,8 +495,10 @@ public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
495495

496496
/// Returns the output obtained from applying the layer to the given input.
497497
///
498-
/// - Parameter input: The input to the layer.
499-
/// - Returns: The output.
498+
/// - Parameter input: The input to the layer of shape,
499+
/// [batch count, input height, input width, input channel count]
500+
/// - Returns: The output of shape,
501+
/// [batch count, output height, output width, input channel count * channel multiplier]
500502
@differentiable
501503
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
502504
return activation(depthwiseConv2D(
@@ -512,7 +514,8 @@ public extension DepthwiseConv2D {
512514
/// element-wise activation function.
513515
///
514516
/// - Parameters:
515-
/// - filterShape: The shape of the 4-D convolution kernel.
517+
/// - filterShape: The shape of the 4-D convolution kernel with form,
518+
/// [filter width, filter height, input channel count, channel multiplier].
516519
/// - strides: The strides of the sliding window for spatial/spatio-temporal dimensions.
517520
/// - padding: The padding algorithm for convolution.
518521
/// - activation: The element-wise activation function.
@@ -530,7 +533,7 @@ public extension DepthwiseConv2D {
530533
filterShape.0, filterShape.1, filterShape.2, filterShape.3])
531534
self.init(
532535
filter: filterInitializer(filterTensorShape),
533-
bias: biasInitializer([filterShape.3]),
536+
bias: biasInitializer([filterShape.2 * filterShape.3]),
534537
activation: activation,
535538
strides: strides,
536539
padding: padding)

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,14 @@ final class LayerTests: XCTestCase {
209209
scalars: [9, 12, 23, 28, 25, 36, 55, 68, 41, 60, 87, 108,
210210
57, 84, 119, 148])
211211
XCTAssertEqual(output, expected)
212+
213+
let channelMultiplier = 4
214+
let multiplierLayer = DepthwiseConv2D<Float>(
215+
filterShape: (2, 2, input.shape[3], channelMultiplier),
216+
filterInitializer: glorotUniform(),
217+
biasInitializer: zeros())
218+
let multiplierOutput = multiplierLayer.inferring(from: input)
219+
XCTAssertEqual(multiplierOutput.shape[3], input.shape[3] * channelMultiplier)
212220
}
213221

214222
func testDepthwiseConv2DGradient() {

0 commit comments

Comments
 (0)