Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ee6172c

Browse files
tanmayb123 authored and rxwei committed
Change operations a bit + add TransposedConv2D layer (#64)
1 parent 8264ac0 commit ee6172c

File tree

2 files changed

+156
-42
lines changed

2 files changed

+156
-42
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,134 @@ public extension Conv2D {
575575
}
576576
}
577577

578+
/// A 2-D transposed convolution layer (e.g. spatial transposed convolution over images).
579+
///
580+
/// This layer creates a convolution filter that is transpose-convolved with the layer input
581+
/// to produce a tensor of outputs.
582+
@_fixed_layout
583+
public struct TransposedConv2D: Layer {
584+
/// The 4-D convolution kernel.
585+
public var filter: Tensor<Float>
586+
/// The bias vector.
587+
public var bias: Tensor<Float>
588+
/// An activation function.
589+
public typealias Activation = @differentiable (Tensor<Float>) -> Tensor<Float>
590+
/// The element-wise activation function.
591+
@noDerivative public let activation: Activation
592+
/// The strides of the sliding window for spatial dimensions.
593+
@noDerivative public let strides: (Int32, Int32)
594+
/// The padding algorithm for convolution.
595+
@noDerivative public let padding: Padding
596+
@noDerivative public let paddingIndex: Int32
597+
598+
/// Creates a `TransposedConv2D` layer with the specified filter, bias,
599+
/// activation function, strides, and padding.
600+
///
601+
/// - Parameters:
602+
/// - filter: The 4-D convolution kernel.
603+
/// - bias: The bias vector.
604+
/// - activation: The element-wise activation function.
605+
/// - strides: The strides of the sliding window for spatial dimensions.
606+
/// - padding: The padding algorithm for convolution.
607+
public init(
608+
filter: Tensor<Float>,
609+
bias: Tensor<Float>,
610+
activation: @escaping Activation,
611+
strides: (Int, Int),
612+
padding: Padding
613+
) {
614+
self.filter = filter
615+
self.bias = bias
616+
self.activation = activation
617+
(self.strides.0, self.strides.1) = (Int32(strides.0), Int32(strides.1))
618+
self.padding = padding
619+
self.paddingIndex = padding == .same ? 0 : 1
620+
}
621+
622+
/// Returns the output obtained from applying the layer to the given input.
623+
///
624+
/// - Parameters:
625+
/// - input: The input to the layer.
626+
/// - context: The contextual information for the layer application, e.g. the current learning
627+
/// phase.
628+
/// - Returns: The output.
629+
@differentiable
630+
public func applied(to input: Tensor<Float>, in _: Context) -> Tensor<Float> {
631+
let batchSize = input.shape[0]
632+
let w = (input.shape[1] - (1 * paddingIndex)) * strides.0 + (filter.shape[0] * paddingIndex)
633+
let h = (input.shape[2] - (1 * paddingIndex)) * strides.1 + (filter.shape[1] * paddingIndex)
634+
let c = filter.shape[2]
635+
let newShape = Tensor<Int32>([batchSize, w, h, c])
636+
return activation(input.conv2DBackpropInput(shape: newShape, filter: filter,
637+
strides: (1, strides.0, strides.1, 1),
638+
padding: padding) + bias)
639+
}
640+
}
641+
642+
public extension TransposedConv2D {
643+
/// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and
644+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
645+
/// initialization with the specified generator. The bias vector is initialized with zeros.
646+
///
647+
/// - Parameters:
648+
/// - filterShape: The shape of the 4-D convolution kernel.
649+
/// - strides: The strides of the sliding window for spatial dimensions.
650+
/// - padding: The padding algorithm for convolution.
651+
/// - activation: The element-wise activation function.
652+
/// - generator: The random number generator for initialization.
653+
///
654+
/// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
655+
/// initialization.
656+
init<G: RandomNumberGenerator>(
657+
filterShape: (Int, Int, Int, Int),
658+
strides: (Int, Int) = (1, 1),
659+
padding: Padding = .valid,
660+
activation: @escaping Activation = identity,
661+
generator: inout G
662+
) {
663+
let filterTensorShape = TensorShape([
664+
Int32(filterShape.0), Int32(filterShape.1),
665+
Int32(filterShape.2), Int32(filterShape.3)])
666+
self.init(
667+
filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
668+
bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
669+
activation: activation,
670+
strides: strides,
671+
padding: padding)
672+
}
673+
}
674+
675+
public extension TransposedConv2D {
676+
/// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and
677+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
678+
/// initialization with the specified seed. The bias vector is initialized with zeros.
679+
///
680+
/// - Parameters:
681+
/// - filterShape: The shape of the 4-D convolution kernel.
682+
/// - strides: The strides of the sliding window for spatial dimensions.
683+
/// - padding: The padding algorithm for convolution.
684+
/// - activation: The element-wise activation function.
685+
/// - seed: The random seed for initialization. The default value is random.
686+
init(
687+
filterShape: (Int, Int, Int, Int),
688+
strides: (Int, Int) = (1, 1),
689+
padding: Padding = .valid,
690+
activation: @escaping Activation = identity,
691+
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
692+
Int64.random(in: Int64.min..<Int64.max))
693+
) {
694+
let filterTensorShape = TensorShape([
695+
Int32(filterShape.0), Int32(filterShape.1),
696+
Int32(filterShape.2), Int32(filterShape.3)])
697+
self.init(
698+
filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
699+
bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
700+
activation: activation,
701+
strides: strides,
702+
padding: padding)
703+
}
704+
}
705+
578706
/// A batch normalization layer.
579707
///
580708
/// Normalizes the activations of the previous layer at each batch, i.e. applies a transformation

Sources/DeepLearning/Operators.swift

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
3434
public extension Tensor where Scalar: TensorFlowFloatingPoint {
3535
// TODO: Verify that these calculations are correct.
3636
@inlinable
37-
func _vjpBatchNormalized(
37+
internal func _vjpBatchNormalized(
3838
alongAxis axis: Int32,
3939
offset: Tensor,
4040
scale: Tensor,
@@ -120,93 +120,79 @@ public extension Padding {
120120
}
121121
}
122122

123-
extension Tensor where Scalar: TensorFlowFloatingPoint {
123+
public extension Tensor where Scalar: TensorFlowFloatingPoint {
124124
/// TensorFlow builtin conv2d gradient helper for the input.
125125
@inlinable
126-
@differentiable(
127-
wrt: (filter, backpropOutput),
128-
vjp: _vjpTFConv2DBackpropInput(_:_:_:_:_:)
129-
)
130-
func _TFConv2DBackpropInput(
126+
@differentiable(wrt: (self, filter), vjp: _vjpConv2DBackpropInput)
127+
internal func conv2DBackpropInput(
131128
shape: Tensor<Int32>,
132129
filter: Tensor,
133-
backpropOutput: Tensor,
134130
strides: (Int32, Int32, Int32, Int32),
135131
padding: Padding
136132
) -> Tensor {
137133
return Raw.conv2DBackpropInput(
138134
inputSizes: shape,
139135
filter: filter,
140-
outBackprop: backpropOutput,
136+
outBackprop: self,
141137
strides: [strides.0, strides.1, strides.2, strides.3],
142138
padding: padding.raw)
143139
}
144140

145-
/// TensorFlow builtin conv2d gradient helper for the filter.
141+
/// TensorFlow builtin conv2d gradient helper for the filter.
146142
@inlinable
147-
@differentiable(
148-
wrt: (input, backpropOutput),
149-
vjp: _vjpTFConv2DBackpropFilter(_:_:_:_:_:)
150-
)
151-
func _TFConv2DBackpropFilter(
143+
@differentiable(wrt: (self, input), vjp: _vjpConv2DBackpropFilter)
144+
internal func conv2DBackpropFilter(
152145
input: Tensor,
153146
filterSizes: Tensor<Int32>,
154-
backpropOutput: Tensor,
155147
strides: (Int32, Int32, Int32, Int32),
156148
padding: Padding
157149
) -> Tensor {
158150
return Raw.conv2DBackpropFilter(
159151
input,
160152
filterSizes: filterSizes,
161-
outBackprop: backpropOutput,
153+
outBackprop: self,
162154
strides: [strides.0, strides.1, strides.2, strides.3],
163155
padding: padding.raw)
164156
}
165157

166158
@inlinable
167-
func _vjpTFConv2DBackpropInput(
159+
internal func _vjpConv2DBackpropInput(
168160
_ shape: Tensor<Int32>,
169161
_ filter: Tensor,
170-
_ backpropOutput: Tensor,
171162
_ strides: (Int32, Int32, Int32, Int32),
172163
_ padding: Padding
173164
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
174-
let value = _TFConv2DBackpropInput(shape: shape, filter: filter,
175-
backpropOutput: backpropOutput,
176-
strides: strides, padding: padding)
165+
let value = conv2DBackpropInput(shape: shape, filter: filter, strides: strides,
166+
padding: padding)
177167
return (value, { v in
178168
return (
179-
self._TFConv2DBackpropFilter(input: v, filterSizes: shape,
180-
backpropOutput: backpropOutput,
181-
strides: strides, padding: padding),
169+
self.conv2DBackpropFilter(input: v, filterSizes: shape, strides: strides,
170+
padding: padding),
182171
v.convolved2D(withFilter: filter, strides: strides, padding: padding)
183172
)
184173
})
185174
}
186175

187176
@inlinable
188-
func _vjpTFConv2DBackpropFilter(
177+
internal func _vjpConv2DBackpropFilter(
189178
_ input: Tensor,
190179
_ filterSizes: Tensor<Int32>,
191-
_ backpropOutput: Tensor,
192180
_ strides: (Int32, Int32, Int32, Int32),
193181
_ padding: Padding
194182
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
195-
let value = _TFConv2DBackpropFilter(input: input, filterSizes: filterSizes,
196-
backpropOutput: backpropOutput,
197-
strides: strides, padding: padding)
183+
let value = conv2DBackpropFilter(input: input, filterSizes: filterSizes,
184+
strides: strides, padding: padding)
198185
return (value, { v in
199186
return (
200-
self._TFConv2DBackpropInput(shape: filterSizes, filter: v,
201-
backpropOutput: backpropOutput,
202-
strides: strides, padding: padding),
187+
self.conv2DBackpropInput(shape: filterSizes, filter: v, strides: strides,
188+
padding: padding),
203189
input.convolved2D(withFilter: v, strides: strides, padding: padding)
204190
)
205191
})
206192
}
207193

208194
@inlinable
209-
func _vjpConvolved2D(
195+
internal func _vjpConvolved2D(
210196
filter: Tensor,
211197
strides: (Int32, Int32, Int32, Int32),
212198
padding: Padding
@@ -215,20 +201,20 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
215201
padding: padding)
216202
return (value, { v in
217203
return (
218-
self._TFConv2DBackpropInput(
219-
shape: self.shapeTensor, filter: filter, backpropOutput: v,
204+
v.conv2DBackpropInput(
205+
shape: self.shapeTensor, filter: filter,
220206
strides: strides, padding: padding
221207
),
222-
self._TFConv2DBackpropFilter(
223-
input: self, filterSizes: filter.shapeTensor, backpropOutput: v,
208+
v.conv2DBackpropFilter(
209+
input: self, filterSizes: filter.shapeTensor,
224210
strides: strides, padding: padding
225211
)
226212
)
227213
})
228214
}
229215

230216
@inlinable
231-
func _vjpMaxPooled(
217+
internal func _vjpMaxPooled(
232218
kernelSize: (Int32, Int32, Int32, Int32),
233219
strides: (Int32, Int32, Int32, Int32),
234220
padding: Padding
@@ -250,7 +236,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
250236
}
251237

252238
@inlinable
253-
func _vjpAveragePooled(
239+
internal func _vjpAveragePooled(
254240
kernelSize: (Int32, Int32, Int32, Int32),
255241
strides: (Int32, Int32, Int32, Int32),
256242
padding: Padding
@@ -284,8 +270,8 @@ public extension Tensor where Scalar: FloatingPoint {
284270
/// - Precondition: `filter` must have rank 4.
285271
@inlinable @inline(__always)
286272
@differentiable(
287-
wrt: (self, filter), vjp: _vjpConvolved2D(filter:strides:padding:)
288-
where Scalar : TensorFlowFloatingPoint
273+
wrt: (self, filter), vjp: _vjpConvolved2D
274+
where Scalar: TensorFlowFloatingPoint
289275
)
290276
func convolved2D(
291277
withFilter filter: Tensor,

0 commit comments

Comments (0)