Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit bf2b462

Browse files
committed
adding conv transpose 3d
1 parent 2c7a190 commit bf2b462

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,131 @@ public extension TransposedConv2D {
485485
padding: padding)
486486
}
487487
}
488+
489+
/// A 3-D transposed convolution layer (e.g. spatial transposed convolution over images).
490+
///
491+
/// This layer creates a convolution filter that is transpose-convolved with the layer input
492+
/// to produce a tensor of outputs.
493+
@_fixed_layout
494+
public struct TransposedConv3D: Layer {
495+
/// The 5-D convolution kernel.
496+
public var filter: Tensor<Float>
497+
/// The bias vector.
498+
public var bias: Tensor<Float>
499+
/// An activation function.
500+
public typealias Activation = @differentiable (Tensor<Float>) -> Tensor<Float>
501+
/// The element-wise activation function.
502+
@noDerivative public let activation: Activation
503+
/// The strides of the sliding window for spatial dimensions.
504+
@noDerivative public let strides: (Int, Int, Int)
505+
/// The padding algorithm for convolution.
506+
@noDerivative public let padding: Padding
507+
@noDerivative public let paddingIndex: Int
508+
509+
/// Creates a `TransposedConv3D` layer with the specified filter, bias,
510+
/// activation function, strides, and padding.
511+
///
512+
/// - Parameters:
513+
/// - filter: The 5-D convolution kernel.
514+
/// - bias: The bias vector.
515+
/// - activation: The element-wise activation function.
516+
/// - strides: The strides of the sliding window for spatial dimensions.
517+
/// - padding: The padding algorithm for convolution.
518+
public init(
519+
filter: Tensor<Float>,
520+
bias: Tensor<Float>,
521+
activation: @escaping Activation,
522+
strides: (Int, Int, Int),
523+
padding: Padding
524+
) {
525+
self.filter = filter
526+
self.bias = bias
527+
self.activation = activation
528+
self.strides = strides
529+
self.padding = padding
530+
self.paddingIndex = padding == .same ? 0 : 1
531+
}
532+
533+
/// Returns the output obtained from applying the layer to the given input.
534+
///
535+
/// - Parameter input: The input to the layer.
536+
/// - Returns: The output.
537+
@differentiable
538+
public func call(_ input: Tensor<Float>) -> Tensor<Float> {
539+
let batchSize = input.shape[0]
540+
let w = (input.shape[1] - (1 * paddingIndex)) *
541+
strides.0 + (filter.shape[0] * paddingIndex)
542+
let h = (input.shape[2] - (1 * paddingIndex)) *
543+
strides.1 + (filter.shape[1] * paddingIndex)
544+
let d = (input.shape[3] - (1 * paddingIndex)) *
545+
strides.2 + (filter.shape[2] * paddingIndex)
546+
let c = filter.shape[3]
547+
let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(d), Int32(c)])
548+
return activation(input.conv2DBackpropInput(shape: newShape, filter: filter,
549+
strides: (1, strides.0, strides.1,
550+
strides.2, 1),
551+
padding: padding) + bias)
552+
}
553+
}
554+
555+
public extension TransposedConv3D {
556+
/// Creates a `TransposedConv3D` layer with the specified filter shape, strides, padding, and
557+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
558+
/// initialization with the specified generator. The bias vector is initialized with zeros.
559+
///
560+
/// - Parameters:
561+
/// - filterShape: The shape of the 5-D convolution kernel.
562+
/// - strides: The strides of the sliding window for spatial dimensions.
563+
/// - padding: The padding algorithm for convolution.
564+
/// - activation: The element-wise activation function.
565+
/// - generator: The random number generator for initialization.
566+
///
567+
/// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
568+
/// initialization.
569+
init<G: RandomNumberGenerator>(
570+
filterShape: (Int, Int, Int, Int, Int),
571+
strides: (Int, Int, Int) = (1, 1, 1),
572+
padding: Padding = .valid,
573+
activation: @escaping Activation = identity,
574+
generator: inout G
575+
) {
576+
let filterTensorShape = TensorShape([
577+
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
578+
self.init(
579+
filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
580+
bias: Tensor(zeros: TensorShape([filterShape.4])),
581+
activation: activation,
582+
strides: strides,
583+
padding: padding)
584+
}
585+
}
586+
587+
public extension TransposedConv3D {
588+
/// Creates a `TransposedConv3D` layer with the specified filter shape, strides, padding, and
589+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
590+
/// initialization with the specified seed. The bias vector is initialized with zeros.
591+
///
592+
/// - Parameters:
593+
/// - filterShape: The shape of the 5-D convolution kernel.
594+
/// - strides: The strides of the sliding window for spatial dimensions.
595+
/// - padding: The padding algorithm for convolution.
596+
/// - activation: The element-wise activation function.
597+
/// - seed: The random seed for initialization. The default value is random.
598+
init(
599+
filterShape: (Int, Int, Int, Int, Int),
600+
strides: (Int, Int, Int) = (1, 1, 1),
601+
padding: Padding = .valid,
602+
activation: @escaping Activation = identity,
603+
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
604+
Int64.random(in: Int64.min..<Int64.max))
605+
) {
606+
let filterTensorShape = TensorShape([
607+
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
608+
self.init(
609+
filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
610+
bias: Tensor(zeros: TensorShape([filterShape.4])),
611+
activation: activation,
612+
strides: strides,
613+
padding: padding)
614+
}
615+
}

0 commit comments

Comments
 (0)