Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit eb996cd

Browse files
dominikgrewerxwei
authored andcommitted
Implement 'isAlmostEqual(to:tolerance:)' method for floating-point tensors. (#348)
* Implement `isAlmostEqual(to:tolerance:)` method for floating-point tensors. * Rename `elementsApproximatelyEqual(_:tolerance:)` to `elementsAlmostEqual(_:tolerance:)`.
1 parent 8d0b1d8 commit eb996cd

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

Sources/TensorFlow/Operators/Comparison.swift

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,15 @@ public extension Tensor where Scalar: Equatable {
207207

208208
// TODO: infix operator ≈: ComparisonPrecedence
209209

210-
public extension Tensor where Scalar: FloatingPoint & Equatable {
210+
public extension Tensor where Scalar: TensorFlowFloatingPoint & Equatable {
211211
/// Returns a tensor of Boolean values indicating whether the elements of `self` are
212212
/// approximately equal to those of `other`.
213213
@inlinable
214-
func elementsApproximatelyEqual(
214+
func elementsAlmostEqual(
215215
_ other: Tensor,
216-
tolerance: Double = 0.00001
216+
tolerance: Scalar = Scalar.ulpOfOne.squareRoot()
217217
) -> Tensor<Bool> {
218-
return Raw.approximateEqual(self, other, tolerance: tolerance)
218+
return Raw.approximateEqual(self, other, tolerance: Double(tolerance))
219219
}
220220
}
221221

@@ -227,3 +227,16 @@ public extension StringTensor {
227227
return Raw.equal(self, other)
228228
}
229229
}
230+
231+
public extension Tensor where Scalar: TensorFlowFloatingPoint {
232+
/// Returns `true` if tensors are of equal shape and all pairs of scalars are approximately
233+
/// equal.
234+
@inlinable
235+
func isAlmostEqual(
236+
to other: Tensor,
237+
tolerance: Scalar = Scalar.ulpOfOne.squareRoot()
238+
) -> Bool {
239+
return self.shape == other.shape &&
240+
self.elementsAlmostEqual(other, tolerance: tolerance).all()
241+
}
242+
}

Tests/TensorFlowTests/OperatorTests/ComparisonTests.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,21 @@ final class ComparisonOperatorTests: XCTestCase {
2828
XCTAssertTrue(x < y)
2929
}
3030

31+
func testIsAlmostEqual() {
32+
let x = Tensor<Float>([0.1, 0.2, 0.3, 0.4])
33+
let y = Tensor<Float>([0.0999, 0.20001, 0.2998, 0.4])
34+
let z = Tensor<Float>([0.0999, 0.20001, 0.2998, 0.3])
35+
36+
XCTAssertTrue(x.isAlmostEqual(to: y, tolerance: 0.01))
37+
XCTAssertFalse(x.isAlmostEqual(to: z))
38+
39+
let nanInf = Tensor<Float>([.nan, .infinity])
40+
XCTAssertFalse(nanInf.isAlmostEqual(to: nanInf))
41+
}
42+
3143
static var allTests = [
3244
("testElementwiseComparison", testElementwiseComparison),
33-
("testLexicographicalComparison", testLexicographicalComparison)
45+
("testLexicographicalComparison", testLexicographicalComparison),
46+
("testIsAlmostEqual", testIsAlmostEqual),
3447
]
3548
}

0 commit comments

Comments
 (0)