Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ed5e615

Browse files
jon-towe authored and eaplatanios committed
Add support for dilated Conv1D and Conv2D (#275)
* Add support for dilated conv1d and conv2d
* Remove accidental whitespaces.
* Add line breaks before parameter listing
* Fix argument indentation
1 parent a1a5406 commit ed5e615

File tree

3 files changed

+132
-38
lines changed

3 files changed

+132
-38
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
3030
@noDerivative public let stride: Int
3131
/// The padding algorithm for convolution.
3232
@noDerivative public let padding: Padding
33-
33+
/// The dilation factor for temporal dimension.
34+
@noDerivative public let dilation: Int
35+
3436
/// Creates a `Conv1D` layer with the specified filter, bias, activation function, stride, and
3537
/// padding.
3638
///
@@ -40,18 +42,21 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
4042
/// - activation: The element-wise activation function.
4143
/// - stride: The stride of the sliding window for temporal dimension.
4244
/// - padding: The padding algorithm for convolution.
45+
/// - dilation: The dilation factor for temporal dimension.
4346
public init(
4447
filter: Tensor<Scalar>,
4548
bias: Tensor<Scalar>,
4649
activation: @escaping Activation,
4750
stride: Int,
48-
padding: Padding
51+
padding: Padding,
52+
dilation: Int
4953
) {
5054
self.filter = filter
5155
self.bias = bias
5256
self.activation = activation
5357
self.stride = stride
5458
self.padding = padding
59+
self.dilation = dilation
5560
}
5661

5762
/// Returns the output obtained from applying the layer to the given input.
@@ -60,8 +65,12 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
6065
/// - Returns: The output `[batchCount, newWidth, outputChannels]`.
6166
@differentiable
6267
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
63-
let conv = conv2D(input.expandingShape(at: 1), filter: filter.expandingShape(at: 0),
64-
strides: (1, 1, stride, 1), padding: padding)
68+
let conv = conv2D(
69+
input.expandingShape(at: 1),
70+
filter: filter.expandingShape(at: 0),
71+
strides: (1, 1, stride, 1),
72+
padding: padding,
73+
dilations: (1, 1, dilation, 1))
6574
return activation(conv.squeezingShape(at: 1) + bias)
6675
}
6776
}
@@ -76,15 +85,17 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger {
7685
/// `[width, inputChannels, outputChannels]`.
7786
/// - stride: The stride of the sliding window for temporal dimension.
7887
/// - padding: The padding algorithm for convolution.
88+
/// - dilation: The dilation factor for temporal dimension.
7989
/// - activation: The element-wise activation function.
8090
/// - generator: The random number generator for initialization.
8191
///
82-
/// - Note: Use `init(filterShape:stride:padding:activation:seed:)` for faster random
92+
/// - Note: Use `init(filterShape:stride:padding:dilation:activation:seed:)` for faster random
8393
/// initialization.
8494
init<G: RandomNumberGenerator>(
8595
filterShape: (Int, Int, Int),
8696
stride: Int = 1,
8797
padding: Padding = .valid,
98+
dilation: Int = 1,
8899
activation: @escaping Activation = identity,
89100
generator: inout G
90101
) {
@@ -95,7 +106,8 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger {
95106
bias: Tensor(zeros: [filterShape.2]),
96107
activation: activation,
97108
stride: stride,
98-
padding: padding)
109+
padding: padding,
110+
dilation: dilation)
99111
}
100112
}
101113

@@ -109,12 +121,14 @@ public extension Conv1D {
109121
/// `[width, inputChannels, outputChannels]`.
110122
/// - stride: The stride of the sliding window for temporal dimension.
111123
/// - padding: The padding algorithm for convolution.
124+
/// - dilation: The dilation factor for the temporal dimension.
112125
/// - activation: The element-wise activation function.
113126
/// - seed: The random seed for initialization. The default value is random.
114127
init(
115128
filterShape: (Int, Int, Int),
116129
stride: Int = 1,
117130
padding: Padding = .valid,
131+
dilation: Int = 1,
118132
activation: @escaping Activation = identity,
119133
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
120134
Int32.random(in: Int32.min..<Int32.max))
@@ -126,7 +140,8 @@ public extension Conv1D {
126140
bias: Tensor(zeros: [filterShape.2]),
127141
activation: activation,
128142
stride: stride,
129-
padding: padding)
143+
padding: padding,
144+
dilation: dilation)
130145
}
131146
}
132147

@@ -148,7 +163,9 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
148163
@noDerivative public let strides: (Int, Int)
149164
/// The padding algorithm for convolution.
150165
@noDerivative public let padding: Padding
151-
166+
/// The dilation factor for spatials dimensions.
167+
@noDerivative public let dilations: (Int, Int)
168+
152169
/// Creates a `Conv2D` layer with the specified filter, bias, activation function, strides, and
153170
/// padding.
154171
///
@@ -158,18 +175,21 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
158175
/// - activation: The element-wise activation function.
159176
/// - strides: The strides of the sliding window for spatial dimensions.
160177
/// - padding: The padding algorithm for convolution.
178+
/// - dilations: The dilation factor for spatials dimensions.
161179
public init(
162180
filter: Tensor<Scalar>,
163181
bias: Tensor<Scalar>,
164182
activation: @escaping Activation,
165183
strides: (Int, Int),
166-
padding: Padding
184+
padding: Padding,
185+
dilations: (Int, Int)
167186
) {
168187
self.filter = filter
169188
self.bias = bias
170189
self.activation = activation
171190
self.strides = strides
172191
self.padding = padding
192+
self.dilations = dilations
173193
}
174194

175195
/// Returns the output obtained from applying the layer to the given input.
@@ -178,8 +198,12 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
178198
/// - Returns: The output.
179199
@differentiable
180200
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
181-
return activation(conv2D(input, filter: filter, strides: (1, strides.0, strides.1, 1),
182-
padding: padding) + bias)
201+
return activation(conv2D(
202+
input,
203+
filter: filter,
204+
strides: (1, strides.0, strides.1, 1),
205+
padding: padding,
206+
dilations: (1, dilations.0, dilations.1, 1)) + bias)
183207
}
184208
}
185209

@@ -192,6 +216,7 @@ public extension Conv2D {
192216
/// - filterShape: The shape of the 4-D convolution kernel.
193217
/// - strides: The strides of the sliding window for spatial dimensions.
194218
/// - padding: The padding algorithm for convolution.
219+
/// - dilations: The dilation factor for spatial dimensions.
195220
/// - activation: The element-wise activation function.
196221
/// - generator: The random number generator for initialization.
197222
///
@@ -201,6 +226,7 @@ public extension Conv2D {
201226
filterShape: (Int, Int, Int, Int),
202227
strides: (Int, Int) = (1, 1),
203228
padding: Padding = .valid,
229+
dilations: (Int, Int) = (1, 1),
204230
activation: @escaping Activation = identity,
205231
generator: inout G
206232
) {
@@ -211,7 +237,8 @@ public extension Conv2D {
211237
bias: Tensor(zeros: [filterShape.3]),
212238
activation: activation,
213239
strides: strides,
214-
padding: padding)
240+
padding: padding,
241+
dilations: dilations)
215242
}
216243
}
217244

@@ -224,12 +251,14 @@ public extension Conv2D {
224251
/// - filterShape: The shape of the 4-D convolution kernel.
225252
/// - strides: The strides of the sliding window for spatial dimensions.
226253
/// - padding: The padding algorithm for convolution.
254+
/// - dilations: The dilation factor for spatial dimensions.
227255
/// - activation: The element-wise activation function.
228256
/// - seed: The random seed for initialization. The default value is random.
229257
init(
230258
filterShape: (Int, Int, Int, Int),
231259
strides: (Int, Int) = (1, 1),
232260
padding: Padding = .valid,
261+
dilations: (Int, Int) = (1, 1),
233262
activation: @escaping Activation = identity,
234263
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
235264
Int32.random(in: Int32.min..<Int32.max))
@@ -241,7 +270,8 @@ public extension Conv2D {
241270
bias: Tensor(zeros: [filterShape.3]),
242271
activation: activation,
243272
strides: strides,
244-
padding: padding)
273+
padding: padding,
274+
dilations: dilations)
245275
}
246276
}
247277

Sources/TensorFlow/Operators/NN.swift

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,17 @@ func conv2DBackpropInput<Scalar: TensorFlowFloatingPoint>(
116116
shape: Tensor<Int32>,
117117
filter: Tensor<Scalar>,
118118
strides: (Int, Int, Int, Int),
119-
padding: Padding
119+
padding: Padding,
120+
dilations: (Int, Int, Int, Int) = (1, 1, 1, 1)
120121
) -> Tensor<Scalar> {
121122
return Raw.conv2DBackpropInput(
122123
inputSizes: shape,
123124
filter: filter,
124125
outBackprop: x,
125126
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
126127
padding: padding.raw2,
127-
explicitPaddings: [])
128+
explicitPaddings: [],
129+
dilations: [Int32(dilations.0), Int32(dilations.1), Int32(dilations.2), Int32(dilations.3)])
128130
}
129131

130132
/// TensorFlow builtin conv2d gradient helper for the filter.
@@ -135,15 +137,17 @@ func conv2DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
135137
input: Tensor<Scalar>,
136138
filterSizes: Tensor<Int32>,
137139
strides: (Int, Int, Int, Int),
138-
padding: Padding
140+
padding: Padding,
141+
dilations: (Int, Int, Int, Int)
139142
) -> Tensor<Scalar> {
140143
return Raw.conv2DBackpropFilter(
141144
input,
142145
filterSizes: filterSizes,
143146
outBackprop: x,
144147
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
145148
padding: padding.raw2,
146-
explicitPaddings: [])
149+
explicitPaddings: [],
150+
dilations: [Int32(dilations.0), Int32(dilations.1), Int32(dilations.2), Int32(dilations.3)])
147151
}
148152

149153
@usableFromInline
@@ -152,13 +156,15 @@ func _vjpConv2DBackpropInput<Scalar: TensorFlowFloatingPoint>(
152156
_ shape: Tensor<Int32>,
153157
_ filter: Tensor<Scalar>,
154158
_ strides: (Int, Int, Int, Int),
155-
_ padding: Padding
159+
_ padding: Padding,
160+
_ dilations: (Int, Int, Int, Int)
156161
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
157162
let value = conv2DBackpropInput(x, shape: shape, filter: filter,
158-
strides: strides, padding: padding)
163+
strides: strides, padding: padding, dilations: dilations)
159164
return (value, { v in
160-
(conv2DBackpropFilter(x, input: v, filterSizes: shape, strides: strides, padding: padding),
161-
conv2D(v, filter: filter, strides: strides, padding: padding))
165+
(conv2DBackpropFilter(x, input: v, filterSizes: shape, strides: strides,
166+
padding: padding, dilations: dilations),
167+
conv2D(v, filter: filter, strides: strides, padding: padding, dilations: dilations))
162168
})
163169
}
164170

@@ -168,13 +174,15 @@ func _vjpConv2DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
168174
_ input: Tensor<Scalar>,
169175
_ filterSizes: Tensor<Int32>,
170176
_ strides: (Int, Int, Int, Int),
171-
_ padding: Padding
177+
_ padding: Padding,
178+
_ dilations: (Int, Int, Int, Int)
172179
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
173180
let value = conv2DBackpropFilter(x, input: input, filterSizes: filterSizes,
174-
strides: strides, padding: padding)
181+
strides: strides, padding: padding, dilations: dilations)
175182
return (value, { v in
176-
(conv2DBackpropInput(x, shape: filterSizes, filter: v, strides: strides, padding: padding),
177-
conv2D(input, filter: v, strides: strides, padding: padding))
183+
(conv2DBackpropInput(x, shape: filterSizes, filter: v, strides: strides,
184+
padding: padding, dilations: dilations),
185+
conv2D(input, filter: v, strides: strides, padding: padding, dilations: dilations))
178186
})
179187
}
180188

@@ -183,14 +191,15 @@ func _vjpConv2D<Scalar: TensorFlowFloatingPoint>(
183191
_ x: Tensor<Scalar>,
184192
filter: Tensor<Scalar>,
185193
strides: (Int, Int, Int, Int),
186-
padding: Padding
194+
padding: Padding,
195+
dilations: (Int, Int, Int, Int)
187196
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
188-
let value = conv2D(x, filter: filter, strides: strides, padding: padding)
197+
let value = conv2D(x, filter: filter, strides: strides, padding: padding, dilations: dilations)
189198
return (value, { v in
190199
(conv2DBackpropInput(v, shape: x.shapeTensor, filter: filter,
191-
strides: strides, padding: padding),
200+
strides: strides, padding: padding, dilations: dilations),
192201
conv2DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
193-
strides: strides, padding: padding))
202+
strides: strides, padding: padding, dilations: dilations))
194203
})
195204
}
196205

@@ -282,11 +291,9 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
282291
return (value, { v in
283292
return (
284293
conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter,
285-
strides: strides, padding: padding
286-
),
294+
strides: strides, padding: padding),
287295
conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
288-
strides: strides, padding: padding
289-
)
296+
strides: strides, padding: padding)
290297
)
291298
})
292299
}
@@ -485,22 +492,26 @@ func _vjpAvgPool3D<Scalar: TensorFlowFloatingPoint>(
485492
/// - input: The input.
486493
/// - filter: The convolution filter.
487494
/// - strides: The strides of the sliding filter for each dimension of the input.
488-
/// - padding: The padding for the operation.
495+
/// - padding: The padding for the operation
496+
/// - dilations: The dilation factor for each dimension of the input.
489497
/// - Precondition: `input` must have rank `4`.
490498
/// - Precondition: `filter` must have rank 4.
491499
@differentiable(wrt: (input, filter), vjp: _vjpConv2D)
492500
public func conv2D<Scalar: TensorFlowFloatingPoint>(
493501
_ input: Tensor<Scalar>,
494502
filter: Tensor<Scalar>,
495503
strides: (Int, Int, Int, Int),
496-
padding: Padding
504+
padding: Padding,
505+
dilations: (Int, Int, Int, Int)
497506
) -> Tensor<Scalar> {
498507
return Raw.conv2D(
499508
input,
500509
filter: filter,
501510
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
502511
padding: padding.raw2,
503-
explicitPaddings: [])
512+
explicitPaddings: [],
513+
dilations: [Int32(dilations.0), Int32(dilations.1), Int32(dilations.2), Int32(dilations.3)]
514+
)
504515
}
505516

506517
/// Returns a 3-D convolution with the specified input, filter, strides, and padding.

0 commit comments

Comments (0)