Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit e63cd23

Browse files
vvmnnnkvrxwei
authored and committed
Add tensor padding modes (#514)
Makes Tensor.padded() on par with tf.pad(). Added padding modes (reflect, symmetric) can be used to make "resize-convolution" layers like in tensorflow/swift-models#191.
1 parent 27ea161 commit e63cd23

File tree

2 files changed

+111
-10
lines changed

2 files changed

+111
-10
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -769,33 +769,65 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
769769
//===------------------------------------------------------------------------------------------===//
770770

771771
public extension Tensor where Scalar: Numeric {
    /// A mode that dictates how a tensor is padded.
    enum PaddingMode {
        /// Pads with a constant value.
        case constant(Scalar)
        /// Mirrors values along padding dimensions, excluding the edge value.
        case reflect
        /// Mirrors values along padding dimensions, including the edge value.
        case symmetric
    }

    /// Returns a tensor padded with a constant according to the specified padding sizes.
    ///
    /// - Parameters:
    ///   - sizes: One `(before, after)` padding pair per dimension.
    ///   - value: The scalar used to fill the padded region; defaults to zero.
    @inlinable
    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
    func padded(forSizes sizes: [(before: Int, after: Int)], with value: Scalar = 0) -> Tensor {
        // Delegate to the mode-based overload so both entry points share one implementation.
        padded(forSizes: sizes, mode: .constant(value))
    }

    /// Returns a padded tensor according to the specified padding sizes and mode.
    ///
    /// - Parameters:
    ///   - sizes: One `(before, after)` padding pair per dimension.
    ///   - mode: How the padded region is filled (`.constant`, `.reflect`, or `.symmetric`).
    @inlinable
    @differentiable(wrt: self, vjp: _vjpPadded(forSizes:mode:) where Scalar: TensorFlowFloatingPoint)
    func padded(forSizes sizes: [(before: Int, after: Int)], mode: PaddingMode) -> Tensor {
        // Build a `[rank, 2]` paddings tensor: row `d` holds the before/after counts
        // for dimension `d`.
        let paddings = Tensor<Int32>(
            shape: [sizes.count, 2],
            scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] })
        switch mode {
        case .constant(let fillValue):
            return Raw.padV2(self, paddings: paddings, constantValues: Tensor(fillValue))
        case .reflect:
            return Raw.mirrorPad(self, paddings: paddings, mode: .reflect)
        case .symmetric:
            return Raw.mirrorPad(self, paddings: paddings, mode: .symmetric)
        }
    }
}
782806

783807
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
    /// Value-with-pullback for `padded(forSizes:mode:)`.
    ///
    /// Returns the padded tensor together with a closure mapping a cotangent of the
    /// padded result back to a cotangent of `self`.
    @inlinable
    func _vjpPadded(
        forSizes sizes: [(before: Int, after: Int)],
        mode: PaddingMode
    ) -> (Tensor, (Tensor) -> Tensor) {
        let result = padded(forSizes: sizes, mode: mode)
        return (result, { [rank = rankTensor, shape = shapeTensor] v in
            // Rebuild the same `[rank, 2]` paddings tensor used by the forward pass.
            let paddings = Tensor<Int32>(
                shape: [sizes.count, 2],
                scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] })
            switch mode {
            case .constant:
                // Constant padding contributes no gradient of its own: slice the
                // incoming gradient back down to the original (unpadded) extent.
                let padBefore = Raw.slice(
                    paddings,
                    begin: Tensor<Int32>([0, 0]),
                    size: Tensor<Int32>(stacking: [rank, Tensor<Int32>(1)]))
                let begin = padBefore.reshaped(to: [-1])
                return v.slice(lowerBounds: begin, sizes: shape)
            case .reflect:
                // Mirrored positions alias original elements, so their gradient
                // contributions must be folded back in; MirrorPadGrad does this.
                return Raw.mirrorPadGrad(v, paddings: paddings, mode: .reflect)
            case .symmetric:
                return Raw.mirrorPadGrad(v, paddings: paddings, mode: .symmetric)
            }
        })
    }
}

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,39 @@ final class BasicOperatorTests: XCTestCase {
4040
XCTAssertEqual(paddedTensor, target)
4141
}
4242

43+
func testPaddedConstant() {
    // Pad a 2x2 tensor of ones: one row before, one column after, filled with 3.
    let input = Tensor<Float>(ones: [2, 2])
    let padded = input.padded(forSizes: [(1, 0), (0, 1)], mode: .constant(3.0))
    let expected = Tensor<Float>([[3, 3, 3], [1, 1, 3], [1, 1, 3]])
    XCTAssertEqual(padded, expected)
}
49+
50+
func testPaddedReflect() {
    // Reflect padding mirrors rows/columns without repeating the edge value.
    let input = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    let padded = input.padded(forSizes: [(2, 0), (0, 2)], mode: .reflect)
    let expected = Tensor<Float>([
        [7, 8, 9, 8, 7],
        [4, 5, 6, 5, 4],
        [1, 2, 3, 2, 1],
        [4, 5, 6, 5, 4],
        [7, 8, 9, 8, 7]
    ])
    XCTAssertEqual(padded, expected)
}
62+
63+
func testPaddedSymmetric() {
    // Symmetric padding mirrors rows/columns including the edge value.
    let input = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    let padded = input.padded(forSizes: [(2, 0), (0, 2)], mode: .symmetric)
    let expected = Tensor<Float>([
        [4, 5, 6, 6, 5],
        [1, 2, 3, 3, 2],
        [1, 2, 3, 3, 2],
        [4, 5, 6, 6, 5],
        [7, 8, 9, 9, 8]
    ])
    XCTAssertEqual(padded, expected)
}
75+
4376
func testVJPPadded() {
4477
let x = Tensor<Float>(ones: [3, 2])
4578
let target = Tensor<Float>([[2, 2], [2, 2], [2, 2]])
@@ -50,6 +83,36 @@ final class BasicOperatorTests: XCTestCase {
5083
XCTAssertEqual(grads, target)
5184
}
5285

86+
func testVJPPaddedConstant() {
    // d/dx sum(pad(x)^2) = 2 * x for constant padding: the padded border
    // contributes nothing to the gradient of the original elements.
    let input = Tensor<Float>(ones: [3, 2])
    let grads = input.gradient { a -> Tensor<Float> in
        let padded = a.padded(forSizes: [(1, 0), (0, 1)], mode: .constant(3.0))
        return (padded * padded).sum()
    }
    let expected = Tensor<Float>([[2, 2], [2, 2], [2, 2]])
    XCTAssertEqual(grads, expected)
}
95+
96+
func testVJPPaddedReflect() {
    // Each original element's gradient is 2 * x scaled by how many times it
    // appears in the reflect-padded output (edge values are not repeated).
    let input = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    let grads = input.gradient { a -> Tensor<Float> in
        let padded = a.padded(forSizes: [(2, 0), (0, 2)], mode: .reflect)
        return (padded * padded).sum()
    }
    let expected = Tensor<Float>([[4, 8, 6], [32, 40, 24], [56, 64, 36]])
    XCTAssertEqual(grads, expected)
}
105+
106+
func testVJPPaddedSymmetric() {
    // Each original element's gradient is 2 * x scaled by how many times it
    // appears in the symmetric-padded output (edge values are repeated).
    let input = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    let grads = input.gradient { a -> Tensor<Float> in
        let padded = a.padded(forSizes: [(2, 0), (0, 2)], mode: .symmetric)
        return (padded * padded).sum()
    }
    let expected = Tensor<Float>([[4, 16, 24], [16, 40, 48], [14, 32, 36]])
    XCTAssertEqual(grads, expected)
}
115+
53116
func testElementIndexing() {
54117
// NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
55118
// until send and receive are implemented (without writing a bunch of mini
@@ -599,7 +662,13 @@ final class BasicOperatorTests: XCTestCase {
599662
("testGathering", testGathering),
600663
("testBatchGathering", testBatchGathering),
601664
("testPadded", testPadded),
665+
("testPaddedConstant", testPaddedConstant),
666+
("testPaddedReflect", testPaddedReflect),
667+
("testPaddedSymmetric", testPaddedSymmetric),
602668
("testVJPPadded", testVJPPadded),
669+
("testVJPPaddedConstant", testVJPPaddedConstant),
670+
("testVJPPaddedReflect", testVJPPaddedReflect),
671+
("testVJPPaddedSymmetric", testVJPPaddedSymmetric),
603672
("testElementIndexing", testElementIndexing),
604673
("testElementIndexingAssignment", testElementIndexingAssignment),
605674
("testNestedElementIndexing", testNestedElementIndexing),

0 commit comments

Comments
 (0)