Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit e787685

Browse files
committed
adding conv transpose 3d
1 parent d7eff12 commit e787685

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,131 @@ public extension TransposedConv2D {
486486
padding: padding)
487487
}
488488
}
489+
490+
/// A 3-D transposed convolution layer (e.g. spatial transposed convolution over images).
491+
///
492+
/// This layer creates a convolution filter that is transpose-convolved with the layer input
493+
/// to produce a tensor of outputs.
494+
@_fixed_layout
495+
public struct TransposedConv3D: Layer {
496+
/// The 5-D convolution kernel.
497+
public var filter: Tensor<Float>
498+
/// The bias vector.
499+
public var bias: Tensor<Float>
500+
/// An activation function.
501+
public typealias Activation = @differentiable (Tensor<Float>) -> Tensor<Float>
502+
/// The element-wise activation function.
503+
@noDerivative public let activation: Activation
504+
/// The strides of the sliding window for spatial dimensions.
505+
@noDerivative public let strides: (Int, Int, Int)
506+
/// The padding algorithm for convolution.
507+
@noDerivative public let padding: Padding
508+
@noDerivative public let paddingIndex: Int
509+
510+
/// Creates a `TransposedConv3D` layer with the specified filter, bias,
511+
/// activation function, strides, and padding.
512+
///
513+
/// - Parameters:
514+
/// - filter: The 5-D convolution kernel.
515+
/// - bias: The bias vector.
516+
/// - activation: The element-wise activation function.
517+
/// - strides: The strides of the sliding window for spatial dimensions.
518+
/// - padding: The padding algorithm for convolution.
519+
public init(
520+
filter: Tensor<Float>,
521+
bias: Tensor<Float>,
522+
activation: @escaping Activation,
523+
strides: (Int, Int, Int),
524+
padding: Padding
525+
) {
526+
self.filter = filter
527+
self.bias = bias
528+
self.activation = activation
529+
self.strides = strides
530+
self.padding = padding
531+
self.paddingIndex = padding == .same ? 0 : 1
532+
}
533+
534+
/// Returns the output obtained from applying the layer to the given input.
535+
///
536+
/// - Parameter input: The input to the layer.
537+
/// - Returns: The output.
538+
@differentiable
539+
public func call(_ input: Tensor<Float>) -> Tensor<Float> {
540+
let batchSize = input.shape[0]
541+
let w = (input.shape[1] - (1 * paddingIndex)) *
542+
strides.0 + (filter.shape[0] * paddingIndex)
543+
let h = (input.shape[2] - (1 * paddingIndex)) *
544+
strides.1 + (filter.shape[1] * paddingIndex)
545+
let d = (input.shape[3] - (1 * paddingIndex)) *
546+
strides.2 + (filter.shape[2] * paddingIndex)
547+
let c = filter.shape[3]
548+
let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(d), Int32(c)])
549+
return activation(input.conv2DBackpropInput(shape: newShape, filter: filter,
550+
strides: (1, strides.0, strides.1,
551+
strides.2, 1),
552+
padding: padding) + bias)
553+
}
554+
}
555+
556+
public extension TransposedConv3D {
557+
/// Creates a `TransposedConv3D` layer with the specified filter shape, strides, padding, and
558+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
559+
/// initialization with the specified generator. The bias vector is initialized with zeros.
560+
///
561+
/// - Parameters:
562+
/// - filterShape: The shape of the 5-D convolution kernel.
563+
/// - strides: The strides of the sliding window for spatial dimensions.
564+
/// - padding: The padding algorithm for convolution.
565+
/// - activation: The element-wise activation function.
566+
/// - generator: The random number generator for initialization.
567+
///
568+
/// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
569+
/// initialization.
570+
init<G: RandomNumberGenerator>(
571+
filterShape: (Int, Int, Int, Int, Int),
572+
strides: (Int, Int, Int) = (1, 1, 1),
573+
padding: Padding = .valid,
574+
activation: @escaping Activation = identity,
575+
generator: inout G
576+
) {
577+
let filterTensorShape = TensorShape([
578+
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
579+
self.init(
580+
filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
581+
bias: Tensor(zeros: TensorShape([filterShape.4])),
582+
activation: activation,
583+
strides: strides,
584+
padding: padding)
585+
}
586+
}
587+
588+
public extension TransposedConv3D {
589+
/// Creates a `TransposedConv3D` layer with the specified filter shape, strides, padding, and
590+
/// element-wise activation function. The filter tensor is initialized using Glorot uniform
591+
/// initialization with the specified seed. The bias vector is initialized with zeros.
592+
///
593+
/// - Parameters:
594+
/// - filterShape: The shape of the 5-D convolution kernel.
595+
/// - strides: The strides of the sliding window for spatial dimensions.
596+
/// - padding: The padding algorithm for convolution.
597+
/// - activation: The element-wise activation function.
598+
/// - seed: The random seed for initialization. The default value is random.
599+
init(
600+
filterShape: (Int, Int, Int, Int, Int),
601+
strides: (Int, Int, Int) = (1, 1, 1),
602+
padding: Padding = .valid,
603+
activation: @escaping Activation = identity,
604+
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
605+
Int64.random(in: Int64.min..<Int64.max))
606+
) {
607+
let filterTensorShape = TensorShape([
608+
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
609+
self.init(
610+
filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
611+
bias: Tensor(zeros: TensorShape([filterShape.4])),
612+
activation: activation,
613+
strides: strides,
614+
padding: padding)
615+
}
616+
}

0 commit comments

Comments
 (0)