Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1702f52

Browse files
Ricardo Ocamposaeta
authored andcommitted
Extract conv1D functionality from Conv1D Layer (#549)
1 parent 786e774 commit 1702f52

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,12 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
7676
/// - Note: Padding size equals zero when using `.valid`.
7777
@differentiable
7878
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
79-
let conv = conv2D(
80-
input.expandingShape(at: 1),
81-
filter: filter.expandingShape(at: 0),
82-
strides: (1, 1, stride, 1),
79+
activation(conv1D(
80+
input,
81+
filter: filter,
82+
stride: stride,
8383
padding: padding,
84-
dilations: (1, 1, dilation, 1))
85-
return activation(conv.squeezingShape(at: 1) + bias)
84+
dilation: dilation) + bias)
8685
}
8786
}
8887

Sources/TensorFlow/Operators/NN.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,35 @@ public extension Padding {
7272
}
7373
}
7474

75+
/// Returns a 1-D convolution with the specified input, filter, stride, and padding.
76+
///
77+
/// - Parameters:
78+
/// - input: The input.
79+
/// - filter: The convolution filter.
80+
/// - stride: The stride of the sliding filter.
81+
/// - padding: The padding for the operation.
82+
/// - dilation: The dilation factor.
83+
/// - Precondition: `input` must have rank `3`.
84+
/// - Precondition: `filter` must have rank 3.
85+
@differentiable(wrt: (input, filter))
86+
public func conv1D<Scalar: TensorFlowFloatingPoint>(
87+
_ input: Tensor<Scalar>,
88+
filter: Tensor<Scalar>,
89+
stride: Int = 1,
90+
padding: Padding = .valid,
91+
dilation: Int = 1
92+
) -> Tensor<Scalar> {
93+
precondition(input.shape.rank == 3, "The input must have rank 3.")
94+
precondition(filter.shape.rank == 3, "The filter must have rank 3.")
95+
return conv2D(
96+
input.expandingShape(at: 1),
97+
filter: filter.expandingShape(at: 0),
98+
strides: (1, 1, stride, 1),
99+
padding: padding,
100+
dilations: (1, 1, dilation, 1)
101+
).squeezingShape(at: 1)
102+
}
103+
75104
/// Returns a 2-D convolution with the specified input, filter, strides, and padding.
76105
///
77106
/// - Parameters:

0 commit comments

Comments
 (0)