Skip to content

Commit 19e24f0

Browse files
lakshya-skydan12411
authored andcommitted
Add derivative for 'padded(forSizes:with:)'. (tensorflow#184)
1 parent cde806c commit 19e24f0

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
674674
public extension Tensor where Scalar: Numeric {
675675
/// Returns a padded tensor according to the specified padding sizes.
676676
@inlinable
677+
@differentiable(wrt: self, vjp: _vjpPadded(forSizes:with:) where Scalar: TensorFlowFloatingPoint)
677678
func padded(forSizes sizes: [(before: Int, after: Int)], with value: Scalar = 0) -> Tensor {
678679
let paddings = Tensor<Int32>(
679680
shape: [sizes.count, 2],
@@ -682,6 +683,26 @@ public extension Tensor where Scalar: Numeric {
682683
}
683684
}
684685

686+
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
687+
@inlinable
688+
func _vjpPadded(
689+
forSizes sizes: [(before: Int, after: Int)],
690+
with value: Scalar
691+
) -> (Tensor, (Tensor) -> Tensor) {
692+
let result = padded(forSizes: sizes, with: value)
693+
return (result, { [rank = rankTensor, shape = shapeTensor] v in
694+
let paddings = Tensor<Int32>(
695+
shape: [sizes.count, 2],
696+
scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] })
697+
let padBefore = Raw.slice(paddings,
698+
begin: Tensor<Int32>([0, 0]),
699+
size: Tensor<Int32>(stacking: [rank, Tensor<Int32>(1)]))
700+
let begin = padBefore.reshaped(to: [-1])
701+
return Raw.slice(v, begin: begin, size: shape)
702+
})
703+
}
704+
}
705+
685706
//===------------------------------------------------------------------------------------------===//
686707
// Indexing and Slicing
687708
//===------------------------------------------------------------------------------------------===//

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ final class BasicOperatorTests: XCTestCase {
3232
XCTAssertEqual(y, Tensor<Float>([3.0, 6.0]))
3333
}
3434

35+
func testPadded() {
36+
let x = Tensor<Float>(ones: [2, 2])
37+
let target = Tensor<Float>([[3, 3, 3], [1, 1, 3], [1, 1, 3]])
38+
let paddedTensor = x.padded(forSizes: [(1, 0), (0, 1)], with: 3.0)
39+
XCTAssertEqual(paddedTensor, target)
40+
}
41+
42+
func testVJPPadded() {
43+
let x = Tensor<Float>(ones: [3, 2])
44+
let target = Tensor<Float>([[2, 2], [2, 2], [2, 2]])
45+
let grads = x.gradient { a -> Tensor<Float> in
46+
let paddedTensor = a.padded(forSizes: [(1, 0), (0, 1)], with: 3.0)
47+
return (paddedTensor * paddedTensor).sum()
48+
}
49+
XCTAssertEqual(grads, target)
50+
}
51+
3552
func testElementIndexing() {
3653
// NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
3754
// until send and receive are implemented (without writing a bunch of mini
@@ -468,6 +485,8 @@ final class BasicOperatorTests: XCTestCase {
468485

469486
static var allTests = [
470487
("testGathering", testGathering),
488+
("testPadded", testPadded),
489+
("testVJPPadded", testVJPPadded),
471490
("testElementIndexing", testElementIndexing),
472491
("testElementIndexingAssignment", testElementIndexingAssignment),
473492
("testNestedElementIndexing", testNestedElementIndexing),

0 commit comments

Comments
 (0)