Skip to content

Commit a3527a7

Browse files
dan-zhengdan12411
authored andcommitted
Fix implementation of root(Tensor, Int). (tensorflow#238)
Fix test failures regarding `root(Tensor, Int)` erroneous returning NaNs. Inspired by the implementation of `root` for `FloatingPoint` types in swift/stdlib/public/core/MathFunctions.swift.gyb.
1 parent 024062d commit a3527a7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ public func pow<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<
876876
@inlinable
877877
// @differentiable
878878
public func root<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> {
879-
pow(x, Tensor(T(1) / T(n)))
879+
sign(x) * pow(abs(x), Tensor(T(1) / T(n)))
880880
}
881881

882882
/// Computes the element-wise maximum of two tensors.

Tests/TensorFlowTests/Helpers.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ internal func assertEqual<T: TensorFlowFloatingPoint>(
2121
) {
2222
for (x, y) in zip(x, y) {
2323
if x.isNaN || y.isNaN {
24-
XCTAssertTrue(x.isNaN && y.isNaN, message, file: file, line: line)
24+
XCTAssertTrue(x.isNaN && y.isNaN,
25+
"\(x) is not equal to \(y) - \(message)",
26+
file: file, line: line)
2527
continue
2628
}
2729
XCTAssertEqual(x, y, accuracy: accuracy, message, file: file, line: line)

0 commit comments

Comments
 (0)