Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Change operations a bit + add TransposedConv2D layer #64

Merged
merged 9 commits into from
Mar 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions Sources/DeepLearning/Layer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,134 @@ public extension Conv2D {
}
}

/// A 2-D transposed convolution layer (e.g. spatial transposed convolution over images).
///
/// This layer creates a convolution filter that is transpose-convolved with the layer input
/// to produce a tensor of outputs.
@_fixed_layout
public struct TransposedConv2D: Layer {
/// The 4-D convolution kernel.
public var filter: Tensor<Float>
/// The bias vector.
public var bias: Tensor<Float>
/// An activation function.
public typealias Activation = @differentiable (Tensor<Float>) -> Tensor<Float>
/// The element-wise activation function.
@noDerivative public let activation: Activation
/// The strides of the sliding window for spatial dimensions.
@noDerivative public let strides: (Int32, Int32)
/// The padding algorithm for convolution.
@noDerivative public let padding: Padding
@noDerivative public let paddingIndex: Int32

/// Creates a `TransposedConv2D` layer with the specified filter, bias,
/// activation function, strides, and padding.
///
/// - Parameters:
/// - filter: The 4-D convolution kernel.
/// - bias: The bias vector.
/// - activation: The element-wise activation function.
/// - strides: The strides of the sliding window for spatial dimensions.
/// - padding: The padding algorithm for convolution.
public init(
filter: Tensor<Float>,
bias: Tensor<Float>,
activation: @escaping Activation,
strides: (Int, Int),
padding: Padding
) {
self.filter = filter
self.bias = bias
self.activation = activation
(self.strides.0, self.strides.1) = (Int32(strides.0), Int32(strides.1))
self.padding = padding
self.paddingIndex = padding == .same ? 0 : 1
}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameters:
/// - input: The input to the layer.
/// - context: The contextual information for the layer application, e.g. the current learning
/// phase.
/// - Returns: The output.
@differentiable
public func applied(to input: Tensor<Float>, in _: Context) -> Tensor<Float> {
let batchSize = input.shape[0]
let w = (input.shape[1] - (1 * paddingIndex)) * strides.0 + (filter.shape[0] * paddingIndex)
let h = (input.shape[2] - (1 * paddingIndex)) * strides.1 + (filter.shape[1] * paddingIndex)
let c = filter.shape[2]
let newShape = Tensor<Int32>([batchSize, w, h, c])
return activation(input.conv2DBackpropInput(shape: newShape, filter: filter,
strides: (1, strides.0, strides.1, 1),
padding: padding) + bias)
}
}

public extension TransposedConv2D {
/// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
/// initialization with the specified generator. The bias vector is initialized with zeros.
///
/// - Parameters:
/// - filterShape: The shape of the 4-D convolution kernel.
/// - strides: The strides of the sliding window for spatial dimensions.
/// - padding: The padding algorithm for convolution.
/// - activation: The element-wise activation function.
/// - generator: The random number generator for initialization.
///
/// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
/// initialization.
init<G: RandomNumberGenerator>(
filterShape: (Int, Int, Int, Int),
strides: (Int, Int) = (1, 1),
padding: Padding = .valid,
activation: @escaping Activation = identity,
generator: inout G
) {
let filterTensorShape = TensorShape([
Int32(filterShape.0), Int32(filterShape.1),
Int32(filterShape.2), Int32(filterShape.3)])
self.init(
filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
activation: activation,
strides: strides,
padding: padding)
}
}

public extension TransposedConv2D {
/// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
/// initialization with the specified seed. The bias vector is initialized with zeros.
///
/// - Parameters:
/// - filterShape: The shape of the 4-D convolution kernel.
/// - strides: The strides of the sliding window for spatial dimensions.
/// - padding: The padding algorithm for convolution.
/// - activation: The element-wise activation function.
/// - seed: The random seed for initialization. The default value is random.
init(
filterShape: (Int, Int, Int, Int),
strides: (Int, Int) = (1, 1),
padding: Padding = .valid,
activation: @escaping Activation = identity,
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
Int64.random(in: Int64.min..<Int64.max))
) {
let filterTensorShape = TensorShape([
Int32(filterShape.0), Int32(filterShape.1),
Int32(filterShape.2), Int32(filterShape.3)])
self.init(
filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
activation: activation,
strides: strides,
padding: padding)
}
}

/// A batch normalization layer.
///
/// Normalizes the activations of the previous layer at each batch, i.e. applies a transformation
Expand Down
70 changes: 28 additions & 42 deletions Sources/DeepLearning/Operators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
public extension Tensor where Scalar: TensorFlowFloatingPoint {
// TODO: Verify that these calculations are correct.
@inlinable
func _vjpBatchNormalized(
internal func _vjpBatchNormalized(
alongAxis axis: Int32,
offset: Tensor,
scale: Tensor,
Expand Down Expand Up @@ -120,93 +120,79 @@ public extension Padding {
}
}

extension Tensor where Scalar: TensorFlowFloatingPoint {
public extension Tensor where Scalar: TensorFlowFloatingPoint {
/// TensorFlow builtin conv2d gradient helper for the input.
@inlinable
@differentiable(
wrt: (filter, backpropOutput),
vjp: _vjpTFConv2DBackpropInput(_:_:_:_:_:)
)
func _TFConv2DBackpropInput(
@differentiable(wrt: (self, filter), vjp: _vjpConv2DBackpropInput)
internal func conv2DBackpropInput(
shape: Tensor<Int32>,
filter: Tensor,
backpropOutput: Tensor,
strides: (Int32, Int32, Int32, Int32),
padding: Padding
) -> Tensor {
return Raw.conv2DBackpropInput(
inputSizes: shape,
filter: filter,
outBackprop: backpropOutput,
outBackprop: self,
strides: [strides.0, strides.1, strides.2, strides.3],
padding: padding.raw)
}

/// TensorFlow builtin conv2d gradient helper for the filter.
/// TensorFlow builtin conv2d gradient helper for the filter.
@inlinable
@differentiable(
wrt: (input, backpropOutput),
vjp: _vjpTFConv2DBackpropFilter(_:_:_:_:_:)
)
func _TFConv2DBackpropFilter(
@differentiable(wrt: (self, input), vjp: _vjpConv2DBackpropFilter)
internal func conv2DBackpropFilter(
input: Tensor,
filterSizes: Tensor<Int32>,
backpropOutput: Tensor,
strides: (Int32, Int32, Int32, Int32),
padding: Padding
) -> Tensor {
return Raw.conv2DBackpropFilter(
input,
filterSizes: filterSizes,
outBackprop: backpropOutput,
outBackprop: self,
strides: [strides.0, strides.1, strides.2, strides.3],
padding: padding.raw)
}

@inlinable
func _vjpTFConv2DBackpropInput(
internal func _vjpConv2DBackpropInput(
_ shape: Tensor<Int32>,
_ filter: Tensor,
_ backpropOutput: Tensor,
_ strides: (Int32, Int32, Int32, Int32),
_ padding: Padding
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
let value = _TFConv2DBackpropInput(shape: shape, filter: filter,
backpropOutput: backpropOutput,
strides: strides, padding: padding)
let value = conv2DBackpropInput(shape: shape, filter: filter, strides: strides,
padding: padding)
return (value, { v in
return (
self._TFConv2DBackpropFilter(input: v, filterSizes: shape,
backpropOutput: backpropOutput,
strides: strides, padding: padding),
self.conv2DBackpropFilter(input: v, filterSizes: shape, strides: strides,
padding: padding),
v.convolved2D(withFilter: filter, strides: strides, padding: padding)
)
})
}

@inlinable
func _vjpTFConv2DBackpropFilter(
internal func _vjpConv2DBackpropFilter(
_ input: Tensor,
_ filterSizes: Tensor<Int32>,
_ backpropOutput: Tensor,
_ strides: (Int32, Int32, Int32, Int32),
_ padding: Padding
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
let value = _TFConv2DBackpropFilter(input: input, filterSizes: filterSizes,
backpropOutput: backpropOutput,
strides: strides, padding: padding)
let value = conv2DBackpropFilter(input: input, filterSizes: filterSizes,
strides: strides, padding: padding)
return (value, { v in
return (
self._TFConv2DBackpropInput(shape: filterSizes, filter: v,
backpropOutput: backpropOutput,
strides: strides, padding: padding),
self.conv2DBackpropInput(shape: filterSizes, filter: v, strides: strides,
padding: padding),
input.convolved2D(withFilter: v, strides: strides, padding: padding)
)
})
}

@inlinable
func _vjpConvolved2D(
internal func _vjpConvolved2D(
filter: Tensor,
strides: (Int32, Int32, Int32, Int32),
padding: Padding
Expand All @@ -215,20 +201,20 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
padding: padding)
return (value, { v in
return (
self._TFConv2DBackpropInput(
shape: self.shapeTensor, filter: filter, backpropOutput: v,
v.conv2DBackpropInput(
shape: self.shapeTensor, filter: filter,
strides: strides, padding: padding
),
self._TFConv2DBackpropFilter(
input: self, filterSizes: filter.shapeTensor, backpropOutput: v,
v.conv2DBackpropFilter(
input: self, filterSizes: filter.shapeTensor,
strides: strides, padding: padding
)
)
})
}

@inlinable
func _vjpMaxPooled(
internal func _vjpMaxPooled(
kernelSize: (Int32, Int32, Int32, Int32),
strides: (Int32, Int32, Int32, Int32),
padding: Padding
Expand All @@ -250,7 +236,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
}

@inlinable
func _vjpAveragePooled(
internal func _vjpAveragePooled(
kernelSize: (Int32, Int32, Int32, Int32),
strides: (Int32, Int32, Int32, Int32),
padding: Padding
Expand Down Expand Up @@ -284,8 +270,8 @@ public extension Tensor where Scalar: FloatingPoint {
/// - Precondition: `filter` must have rank 4.
@inlinable @inline(__always)
@differentiable(
wrt: (self, filter), vjp: _vjpConvolved2D(filter:strides:padding:)
where Scalar : TensorFlowFloatingPoint
wrt: (self, filter), vjp: _vjpConvolved2D
where Scalar: TensorFlowFloatingPoint
)
func convolved2D(
withFilter filter: Tensor,
Expand Down