Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f0a6da6

Browse files
authored
[Tensor] Add a VJP for 'init(_:)' (scalars) and 'init(shape:scalars:)'. (#512)
Add a VJP for the `init(_:)` which takes an array of scalars and `init(shape:scalars:)` which takes a shape and an array of scalars. Both VJPs are merely returning the original value and a transpose, because `init(_:)` and `init(shape:scalars:)` are linear. Resolves #510 and #511.
1 parent 176b6ff commit f0a6da6

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ public extension Tensor {
169169
}
170170

171171
extension Tensor where Scalar: TensorFlowFloatingPoint {
172-
@usableFromInline
172+
@inlinable
173173
func _vjpScalars() -> (value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor) {
174174
(value: scalars, pullback: { [shape = self.shape] v in
175175
Tensor(shape: shape, scalars: v.base)
@@ -200,6 +200,7 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
200200
public extension Tensor {
201201
/// Creates a 1D tensor from scalars.
202202
@inlinable
203+
@differentiable(vjp: _vjpInit(_:) where Scalar: TensorFlowFloatingPoint)
203204
init(_ scalars: [Scalar]) {
204205
self.init(shape: [scalars.count], scalars: scalars)
205206
}
@@ -226,6 +227,7 @@ public extension Tensor {
226227
/// - scalars: The scalar contents of the tensor.
227228
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
228229
@inlinable
230+
@differentiable(vjp: _vjpInit(shape:scalars:) where Scalar: TensorFlowFloatingPoint)
229231
init(shape: TensorShape, scalars: [Scalar]) {
230232
precondition(shape.contiguousSize == scalars.count,
231233
"""
@@ -284,6 +286,22 @@ public extension Tensor {
284286
}
285287
}
286288

289+
extension Tensor where Scalar: TensorFlowFloatingPoint {
290+
@inlinable
291+
static func _vjpInit(_ scalars: [Scalar]) -> (
292+
value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector
293+
) {
294+
(value: Tensor(scalars), pullback: { v in Array<Scalar>.TangentVector(v.scalars) })
295+
}
296+
297+
@inlinable
298+
static func _vjpInit(shape: TensorShape, scalars: [Scalar]) -> (
299+
value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector
300+
) {
301+
(value: Tensor(scalars), pullback: { v in Array<Scalar>.TangentVector(v.scalars) })
302+
}
303+
}
304+
287305
// Background story on `TensorElementLiteral` and why it's necessary:
288306
//
289307
// Very importantly, we want users to be able to implicitly convert an array

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,20 @@ final class TensorAutoDiffTests: XCTestCase {
6363
XCTAssertEqual(grad, Tensor([1, 1]))
6464
}
6565

66+
func testInitFromScalars() {
67+
let grad = gradient(at: [3.0, 4.0]) { x in
68+
Tensor(x).sum()
69+
}
70+
XCTAssertEqual(grad, Array<Double>.TangentVector([1, 1]))
71+
}
72+
73+
func testInitFromScalarsWithShape() {
74+
let grad = gradient(at: [3.0, 4.0]) { x in
75+
Tensor(shape: [1, 2, 1, 1], scalars: x).sum()
76+
}
77+
XCTAssertEqual(grad, Array<Double>.TangentVector([1, 1]))
78+
}
79+
6680
func testPlus() {
6781
func f(a: Tensor<Float>, b: Tensor<Float>) -> Tensor<Float> { a + b }
6882
XCTAssertTrue((Tensor(1), Tensor(1)) == gradient(at: Tensor(0), Tensor(0), in: f))
@@ -516,6 +530,9 @@ final class TensorAutoDiffTests: XCTestCase {
516530
("testGenericGrad", testGenericGrad),
517531
("testScalarGenericGrad", testScalarGenericGrad),
518532
("testScalarized", testScalarized),
533+
("testScalars", testScalars),
534+
("testInitFromScalars", testInitFromScalars),
535+
("testInitFromScalarsWithShape", testInitFromScalarsWithShape),
519536
("testPlus", testPlus),
520537
("testSubtract", testSubtract),
521538
("testMultiply", testMultiply),

0 commit comments

Comments (0)