Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0862d9b

Browse files
authored
Error message for non-scalar Tensor case. (#278)
When taking the gradient of a function, the function must return a scalar. Because Tensors do not encode their shape or rank, this is checked dynamically at runtime. Previously, we did not provide an explicit error message, resulting in the following printed to the console: ``` Precondition failed: file /swift-base/tensorflow-swift-apis/Sources/TensorFlow/Core/DifferentialOperators.swift, line 68 Current stack trace: 0 libswiftCore.so 0x00007f1629f45830 swift_reportError + 50 1 libswiftCore.so 0x00007f1629fb45d0 _swift_stdlib_reportFatalErrorInFile + 115 2 libswiftCore.so 0x00007f1629edca7e <unavailable> + 3734142 3 libswiftCore.so 0x00007f1629edcbf7 <unavailable> + 3734519 4 libswiftCore.so 0x00007f1629cab10d <unavailable> + 1433869 5 libswiftCore.so 0x00007f1629eb1a88 <unavailable> + 3558024 6 libswiftCore.so 0x00007f1629caa569 <unavailable> + 1430889 9 repl_swift 0x0000000000400490 <unavailable> + 1168 11 libswiftCore.so 0x00007f1629caa569 <unavailable> + 1430889 14 repl_swift 0x0000000000400490 <unavailable> + 1168 16 libswiftCore.so 0x00007f1629caa569 <unavailable> + 1430889 17 libswiftTensorFlow.so 0x00007f16272250b0 <unavailable> + 2543792 18 libswiftTensorFlow.so 0x00007f1627098280 checkOk(_:file:line:) + 467 19 libswiftTensorFlow.so 0x00007f162709f480 TFE_Op.evaluateUnsafe() + 506 20 libswiftTensorFlow.so 0x00007f162709fcf0 TFE_Op.execute<A>(_:) + 132 21 libswiftTensorFlow.so 0x00007f16270a8984 <unavailable> + 985476 22 libswiftTensorFlow.so 0x00007f162713dc20 static Raw.matMul<A>(_:_:transposeA:transposeB:) + 1221 23 libswiftTensorFlow.so 0x00007f1627293c00 matmul<A>(_:transposed:_:transposed:) + 1427 24 libswiftTensorFlow.so 0x00007f16272f6210 _vjpMatmul<A>(_:transposed:_:transposed:) + 201 25 libswiftTensorFlow.so 0x00007f16273564b4 <unavailable> + 3794100 26 libswiftTensorFlow.so 0x00007f162731f960 AD__$s10TensorFlow5DenseV14callAsFunctionyAA0A0VyxGAGF__vjp_src_0_wrt_0_1 + 680 31 repl_swift 0x0000000000400490 <unavailable> + 1168 Current stack trace: frame #4: 0x00007f162b926120 $__lldb_expr57`main at <Cell 8>:1 ``` This change provides the user with a bit more context and prints out the shape their computation produced to help them debug what might have gone wrong.
1 parent a265f8f commit 0862d9b

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

Sources/TensorFlow/Core/DifferentialOperators.swift

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ public extension Differentiable {
2929
in f: @differentiable (Self) -> Tensor<R>
3030
) -> (value: Tensor<R>, gradient: TangentVector) {
3131
let (y, pb) = self.valueWithPullback(in: f)
32-
precondition(y.rank == 0)
32+
precondition(y.rank == 0, """
33+
The function being differentiated produced a tensor with shape \(y.shape). \
34+
You can only compute the gradient of functions that return scalar values.
35+
""")
3336
return (y, pb(Tensor<R>(1)))
3437
}
3538

@@ -47,7 +50,10 @@ public extension Differentiable {
4750
in f: @differentiable (Self, T) -> Tensor<R>
4851
) -> (value: Tensor<R>, gradient: (TangentVector, T.TangentVector)) {
4952
let (y, pb) = self.valueWithPullback(at: x, in: f)
50-
precondition(y.rank == 0)
53+
precondition(y.rank == 0, """
54+
The function being differentiated produced a tensor with shape \(y.shape). \
55+
You can only compute the gradient of functions that return scalar values.
56+
""")
5157
return (y, pb(Tensor<R>(1)))
5258
}
5359
}
@@ -65,7 +71,10 @@ public func valueWithGradient<T, R>(
6571
) -> (value: Tensor<R>, gradient: T.TangentVector)
6672
where T: Differentiable, R: TensorFlowFloatingPoint {
6773
let (y, pullback) = valueWithPullback(at: x, in: f)
68-
precondition(y.rank == 0)
74+
precondition(y.rank == 0, """
75+
The function being differentiated produced a tensor with shape \(y.shape). \
76+
You can only compute the gradient of functions that return scalar values.
77+
""")
6978
return (y, pullback(Tensor<R>(1)))
7079
}
7180

@@ -77,7 +86,10 @@ public func valueWithGradient<T, U, R>(
7786
) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
7887
where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
7988
let (y, pullback) = valueWithPullback(at: x, y, in: f)
80-
precondition(y.rank == 0)
89+
precondition(y.rank == 0, """
90+
The function being differentiated produced a tensor with shape \(y.shape). \
91+
You can only compute the gradient of functions that return scalar values.
92+
""")
8193
return (y, pullback(Tensor<R>(1)))
8294
}
8395

0 commit comments

Comments
 (0)