Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Implement isAlmostEqual method for floating-point tensors. #348

Merged
merged 3 commits into from
Jul 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions Sources/TensorFlow/Operators/Comparison.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ public extension Tensor where Scalar: Equatable {

// TODO: infix operator ≈: ComparisonPrecedence

public extension Tensor where Scalar: FloatingPoint & Equatable {
public extension Tensor where Scalar: TensorFlowFloatingPoint & Equatable {
/// Returns a tensor of Boolean values indicating whether the elements of `self` are
/// approximately equal to those of `other`.
@inlinable
func elementsApproximatelyEqual(
func elementsAlmostEqual(
_ other: Tensor,
tolerance: Double = 0.00001
tolerance: Scalar = Scalar.ulpOfOne.squareRoot()
) -> Tensor<Bool> {
return Raw.approximateEqual(self, other, tolerance: tolerance)
return Raw.approximateEqual(self, other, tolerance: Double(tolerance))
}
}

Expand All @@ -227,3 +227,16 @@ public extension StringTensor {
return Raw.equal(self, other)
}
}

public extension Tensor where Scalar: TensorFlowFloatingPoint {
/// Returns `true` if tensors are of equal shape and all pairs of scalars are approximately
/// equal.
@inlinable
func isAlmostEqual(
to other: Tensor,
tolerance: Scalar = Scalar.ulpOfOne.squareRoot()
) -> Bool {
return self.shape == other.shape &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use Tensor.elementsApproximatelyEqual(_:tolerance:) directly to make this faster. Also, maybe we should also add the to argument label to all these other functions that check for element-wise equality.

Copy link
Contributor

@rxwei rxwei Jul 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, maybe we should also add the to argument label to all these other functions that check for element-wise equality.

This elementsApproximatelyEqual(_:tolerance:) method name is a sentence rather than a verb phrase, and "equal" here is a transitive verb. Same for Sequence.elementsEqual(_:).

We should change "approximately" to "almost" for consistency with isAlmostEqual(to:).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've renamed elementsApproximatelyEqual to elementsAlmostEqual and this is now called from isAlmostEqual. For consistency I also changed the condition on Scalar to now conform to TensorFlowFloatingPoint so that it's convertible to Double. Ptal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we want to have this check for mismatching shapes here. @rxwei wouldn't a precondition be better? I believe it would be misleading for this to return false when the shapes are mismatching as it should be expected that it operates over tensors with matching shapes.

Another issue is that of broadcasting. If we want to support broadcasting for this function we should remove that check altogether.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

self.elementsAlmostEqual(other, tolerance: tolerance).all()
}
}
15 changes: 14 additions & 1 deletion Tests/TensorFlowTests/OperatorTests/ComparisonTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,21 @@ final class ComparisonOperatorTests: XCTestCase {
XCTAssertTrue(x < y)
}

func testIsAlmostEqual() {
let x = Tensor<Float>([0.1, 0.2, 0.3, 0.4])
let y = Tensor<Float>([0.0999, 0.20001, 0.2998, 0.4])
let z = Tensor<Float>([0.0999, 0.20001, 0.2998, 0.3])

XCTAssertTrue(x.isAlmostEqual(to: y, tolerance: 0.01))
XCTAssertFalse(x.isAlmostEqual(to: z))

let nanInf = Tensor<Float>([.nan, .infinity])
XCTAssertFalse(nanInf.isAlmostEqual(to: nanInf))
}

static var allTests = [
("testElementwiseComparison", testElementwiseComparison),
("testLexicographicalComparison", testLexicographicalComparison)
("testLexicographicalComparison", testLexicographicalComparison),
("testIsAlmostEqual", testIsAlmostEqual),
]
}