-
Notifications
You must be signed in to change notification settings - Fork 137
Add derivative for 'padded(forSizes:with:)'. #184
Conversation
Fix: #179 |
@@ -491,6 +502,7 @@ final class BasicOperatorTests: XCTestCase { | |||
("testUnbroadcast1", testUnbroadcast1), | |||
("testUnbroadcast2", testUnbroadcast2), | |||
("testSliceUpdate", testSliceUpdate), | |||
("testVJPPadded", testVJPPadded), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a test for padded too as part of this PR since it’s missing?
And also maybe move them both right after testGathering
in an attempt to keep the order consistent between the test files and the source files?
@@ -674,6 +674,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint { | |||
public extension Tensor where Scalar: Numeric { | |||
/// Returns a padded tensor according to the specified padding sizes. | |||
@inlinable | |||
@differentiable(wrt:self, vjp: _vjpPadded(forSizes:with:) where Scalar: TensorFlowFloatingPoint) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wrt:self
-> wrt: self
begin: Tensor<Int32>([0, 0]), | ||
size: Tensor<Int32>(stacking: [rank, Tensor<Int32>(1)])) | ||
let begin = Raw.reshape(padBefore, shape: Tensor<Int32>([-1])) | ||
return Raw.slice(v,begin: begin,size: shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v,begin
-> v, begin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
begin,size
-> begin, size
return (result, { [rank = rankTensor, shape = shapeTensor] v in | ||
let paddings = Tensor<Int32>( | ||
shape: [sizes.count, 2], | ||
scalars: sizes.flatMap {[Int32($0.before), Int32($0.after)] }) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Space after {
.
@@ -466,6 +466,17 @@ final class BasicOperatorTests: XCTestCase { | |||
XCTAssertEqual(target, Tensor(repeating: 1, shape: [2, 3, 4])) | |||
} | |||
|
|||
func testVJPPadded() { | |||
// 1 -> 2 x 3 x 4 | |||
let x = Tensor<Float>(ones:[3,2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the test function code, could you please add spaces after all the argument labels?
And also between the literals in arrays and tuples?
@eaplatanios All the requested changes are done. |
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Co-Authored-By: Richard Wei <[email protected]>
Resolves #179.