This repository was archived by the owner on Jul 1, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +31
-5
lines changed
Sources/TensorFlow/Operators
Tests/TensorFlowTests/OperatorTests Expand file tree Collapse file tree 2 files changed +31
-5
lines changed Original file line number Diff line number Diff line change @@ -207,15 +207,15 @@ public extension Tensor where Scalar: Equatable {
207
207
208
208
// TODO: infix operator ≈: ComparisonPrecedence
209
209
210
- public extension Tensor where Scalar: FloatingPoint & Equatable {
210
+ public extension Tensor where Scalar: TensorFlowFloatingPoint & Equatable {
211
211
/// Returns a tensor of Boolean values indicating whether the elements of `self` are
212
212
/// approximately equal to those of `other`.
213
213
@inlinable
214
- func elementsApproximatelyEqual (
214
+ func elementsAlmostEqual (
215
215
_ other: Tensor ,
216
- tolerance: Double = 0.00001
216
+ tolerance: Scalar = Scalar . ulpOfOne . squareRoot ( )
217
217
) -> Tensor < Bool > {
218
- return Raw . approximateEqual ( self , other, tolerance: tolerance)
218
+ return Raw . approximateEqual ( self , other, tolerance: Double ( tolerance) )
219
219
}
220
220
}
221
221
@@ -227,3 +227,16 @@ public extension StringTensor {
227
227
return Raw . equal ( self , other)
228
228
}
229
229
}
230
+
231
+ public extension Tensor where Scalar: TensorFlowFloatingPoint {
232
+ /// Returns `true` if tensors are of equal shape and all pairs of scalars are approximately
233
+ /// equal.
234
+ @inlinable
235
+ func isAlmostEqual(
236
+ to other: Tensor ,
237
+ tolerance: Scalar = Scalar . ulpOfOne. squareRoot ( )
238
+ ) -> Bool {
239
+ return self . shape == other. shape &&
240
+ self . elementsAlmostEqual ( other, tolerance: tolerance) . all ( )
241
+ }
242
+ }
Original file line number Diff line number Diff line change @@ -28,8 +28,21 @@ final class ComparisonOperatorTests: XCTestCase {
28
28
XCTAssertTrue ( x < y)
29
29
}
30
30
31
+ func testIsAlmostEqual( ) {
32
+ let x = Tensor < Float > ( [ 0.1 , 0.2 , 0.3 , 0.4 ] )
33
+ let y = Tensor < Float > ( [ 0.0999 , 0.20001 , 0.2998 , 0.4 ] )
34
+ let z = Tensor < Float > ( [ 0.0999 , 0.20001 , 0.2998 , 0.3 ] )
35
+
36
+ XCTAssertTrue ( x. isAlmostEqual ( to: y, tolerance: 0.01 ) )
37
+ XCTAssertFalse ( x. isAlmostEqual ( to: z) )
38
+
39
+ let nanInf = Tensor < Float > ( [ . nan, . infinity] )
40
+ XCTAssertFalse ( nanInf. isAlmostEqual ( to: nanInf) )
41
+ }
42
+
31
43
static var allTests = [
32
44
( " testElementwiseComparison " , testElementwiseComparison) ,
33
- ( " testLexicographicalComparison " , testLexicographicalComparison)
45
+ ( " testLexicographicalComparison " , testLexicographicalComparison) ,
46
+ ( " testIsAlmostEqual " , testIsAlmostEqual) ,
34
47
]
35
48
}
You can’t perform that action at this time.
0 commit comments