Skip to content

Commit 6c0dbd9

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add EXPECT_CLOSE macros with tol arguments
Reviewed By: SS-JIA Differential Revision: D48483563 fbshipit-source-id: 67db90b068aeb68f3f3662aa877d6079024a6ad8
1 parent cfd36a6 commit 6c0dbd9

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

runtime/core/exec_aten/testing_util/tensor_util.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ MATCHER_P(IsCloseTo, other, "") {
117117
return tensors_are_close(arg, other);
118118
}
119119

120+
/**
121+
* Lets gtest users write
122+
* `EXPECT_THAT(tensor1, IsCloseToWithTol(tensor2, rtol, atol))`
123+
* or `EXPECT_THAT(tensor1, Not(IsCloseToWithTol(tensor2, rtol, atol)))`.
124+
*
125+
* See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and
126+
* `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`.
127+
*/
128+
MATCHER_P3(IsCloseToWithTol, other, rtol, atol, "") {
129+
return tensors_are_close(arg, other, rtol, atol);
130+
}
131+
120132
/**
121133
* Lets gtest users write `EXPECT_THAT(tensor1, IsEqualTo(tensor2))` or
122134
* `EXPECT_THAT(tensor1, Not(IsEqualTo(tensor2)))`.
@@ -136,6 +148,19 @@ MATCHER_P(IsEqualTo, other, "") {
136148
MATCHER_P(IsDataCloseTo, other, "") {
137149
return tensor_data_is_close(arg, other);
138150
}
151+
152+
/**
153+
* Lets gtest users write
154+
* `EXPECT_THAT(tensor1, IsDataCloseToWithTol(tensor2, rtol, atol))`
155+
* or `EXPECT_THAT(tensor1, Not(IsDataCloseToWithTol(tensor2, rtol, atol)))`.
156+
*
157+
* See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and
158+
* `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`.
159+
*/
160+
MATCHER_P3(IsDataCloseToWithTol, other, rtol, atol, "") {
161+
return tensor_data_is_close(arg, other, rtol, atol);
162+
}
163+
139164
/**
140165
* Lets gtest users write `EXPECT_THAT(tensor1, IsDataEqualTo(tensor2))` or
141166
* `EXPECT_THAT(tensor1, Not(IsDataEqualTo(tensor2)))`.
@@ -205,6 +230,23 @@ MATCHER_P(IsListEqualTo, other, "") {
205230
#define ASSERT_TENSOR_NOT_CLOSE(t1, t2) \
206231
ASSERT_THAT((t1), ::testing::Not(torch::executor::testing::IsCloseTo(t2)))
207232

233+
#define EXPECT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
234+
EXPECT_THAT( \
235+
(t1), ::torch::executor::testing::IsCloseToWithTol(t2, rtol, atol))
236+
#define EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
237+
EXPECT_THAT( \
238+
(t1), \
239+
::testing::Not( \
240+
torch::executor::testing::IsCloseToWithTol(t2, rtol, atol)))
241+
#define ASSERT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
242+
ASSERT_THAT( \
243+
(t1), ::torch::executor::testing::IsCloseToWithTol(t2, rtol, atol))
244+
#define ASSERT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
245+
ASSERT_THAT( \
246+
(t1), \
247+
::testing::Not( \
248+
torch::executor::testing::IsCloseToWithTol(t2, rtol, atol)))
249+
208250
#define EXPECT_TENSOR_DATA_EQ(t1, t2) \
209251
EXPECT_THAT((t1), ::torch::executor::testing::IsDataEqualTo(t2))
210252
#define EXPECT_TENSOR_DATA_NE(t1, t2) \
@@ -223,6 +265,23 @@ MATCHER_P(IsListEqualTo, other, "") {
223265
#define ASSERT_TENSOR_DATA_NOT_CLOSE(t1, t2) \
224266
ASSERT_THAT((t1), ::testing::Not(torch::executor::testing::IsDataCloseTo(t2)))
225267

268+
#define EXPECT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
269+
EXPECT_THAT( \
270+
(t1), ::torch::executor::testing::IsDataCloseToWithTol(t2, rtol, atol))
271+
#define EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
272+
EXPECT_THAT( \
273+
(t1), \
274+
::testing::Not( \
275+
torch::executor::testing::IsDataCloseToWithTol(t2, rtol, atol)))
276+
#define ASSERT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
277+
ASSERT_THAT( \
278+
(t1), ::torch::executor::testing::IsDataCloseToWithTol(t2, rtol, atol))
279+
#define ASSERT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
280+
ASSERT_THAT( \
281+
(t1), \
282+
::testing::Not( \
283+
torch::executor::testing::IsDataCloseToWithTol(t2, rtol, atol)))
284+
226285
/*
227286
* Helpers for comparing lists of Tensors.
228287
*/

runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,48 @@ TEST(TensorUtilTest, DoubleAndInfinitNanElementsAreCloseAndEqual) {
420420
EXPECT_TENSORS_CLOSE_AND_EQUAL(a, b);
421421
}
422422

423+
// Testing closeness with tolerances
424+
425+
TEST(TensorUtilTest, TensorsAreCloseWithTol) {
426+
TensorFactory<ScalarType::Float> tf;
427+
TensorFactory<ScalarType::Double> td;
428+
429+
// Create two tensors with identical shape and dtype, but different data.
430+
Tensor af = tf.make(/*sizes=*/{2, 2}, /*data=*/{1.0, 2.099999, 0.0, -0.05});
431+
Tensor bf = tf.make(/*sizes=*/{2, 2}, /*data=*/{1.099999, 2.0, 0.05, 0.0});
432+
433+
EXPECT_TENSOR_CLOSE_WITH_TOL(af, bf, 0.0, 0.1);
434+
435+
// Create two tensors with identical shape and dtype, but different data.
436+
Tensor ad = td.make(/*sizes=*/{2, 2}, /*data=*/{1.099, 2.199, NAN, -9.0});
437+
Tensor bd = td.make(/*sizes=*/{2, 2}, /*data=*/{1.0, 2.0, NAN, -10.0});
438+
439+
EXPECT_TENSOR_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
440+
}
441+
442+
TEST(TensorUtilTest, TensorsAreNotCloseWithTol) {
443+
TensorFactory<ScalarType::Float> tf;
444+
TensorFactory<ScalarType::Double> td;
445+
446+
// Create two tensors with identical shape and dtype, but different data.
447+
Tensor af = tf.make(/*sizes=*/{3}, /*data=*/{1.00, NAN, -10.0});
448+
Tensor bf = tf.make(/*sizes=*/{3}, /*data=*/{1.11, NAN, -10.0});
449+
450+
EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(af, bf, 0.0, 0.1);
451+
452+
// Create two tensors with identical shape and dtype, but different data.
453+
Tensor ad = td.make(/*sizes=*/{3}, /*data=*/{1.0, 0.0, -10.0});
454+
Tensor bd = td.make(/*sizes=*/{3}, /*data=*/{1.0, 0.0, -9.0});
455+
456+
EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
457+
458+
// Create two tensors with identical shape and dtype, but different data.
459+
ad = tf.make(/*sizes=*/{3}, /*data=*/{1.0, 2.0, 0.00001});
460+
bd = tf.make(/*sizes=*/{3}, /*data=*/{1.0, 2.0, 0.0});
461+
462+
EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
463+
}
464+
423465
//
424466
// Tests for shape-agnostic data equality.
425467
//
@@ -585,6 +627,48 @@ TEST(TensorUtilTest, TensorDataMismatched) {
585627
EXPECT_TENSORS_DATA_NOT_CLOSE_OR_EQUAL(t_zero_dim, t_empty);
586628
}
587629

630+
// Testing data closeness with tolerances
631+
632+
TEST(TensorUtilTest, TensorDataCloseWithTol) {
633+
TensorFactory<ScalarType::Float> tf;
634+
TensorFactory<ScalarType::Double> td;
635+
636+
// Create two tensors with identical shape and dtype, but different data.
637+
Tensor af = tf.make(/*sizes=*/{4, 1}, /*data=*/{1.0, 2.099, 0.0, -0.05});
638+
Tensor bf = tf.make(/*sizes=*/{2, 2}, /*data=*/{1.099, 2.0, 0.05, 0.0});
639+
640+
EXPECT_TENSOR_DATA_CLOSE_WITH_TOL(af, bf, 0.0, 0.1);
641+
642+
// Create two tensors with identical shape and dtype, but different data.
643+
Tensor ad = td.make(/*sizes=*/{2, 2}, /*data=*/{1.099, 2.199, NAN, -9.0});
644+
Tensor bd = td.make(/*sizes=*/{4}, /*data=*/{1.0, 2.0, NAN, -10.0});
645+
646+
EXPECT_TENSOR_DATA_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
647+
}
648+
649+
TEST(TensorUtilTest, TensorDataNotCloseWithTol) {
650+
TensorFactory<ScalarType::Float> tf;
651+
TensorFactory<ScalarType::Double> td;
652+
653+
// Create two tensors with identical shape and dtype, but different data.
654+
Tensor af = tf.make(/*sizes=*/{3}, /*data=*/{1.00, 0.0, -10.0});
655+
Tensor bf = tf.make(/*sizes=*/{3, 1}, /*data=*/{1.11, 0.0, -10.0});
656+
657+
EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(af, bf, 0.0, 0.1);
658+
659+
// Create two tensors with identical shape and dtype, but different data.
660+
Tensor ad = td.make(/*sizes=*/{2, 2}, /*data=*/{1.0, 0.0, -10.0, 0.0});
661+
Tensor bd = td.make(/*sizes=*/{4}, /*data=*/{1.0, 0.0, -9.0, 0.0});
662+
663+
EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
664+
665+
// Create two tensors with identical shape and dtype, but different data.
666+
ad = tf.make(/*sizes=*/{1, 4}, /*data=*/{1.0, 2.0, NAN, 0.00001});
667+
bd = tf.make(/*sizes=*/{2, 2}, /*data=*/{1.0, 2.0, NAN, 0.0});
668+
669+
EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(ad, bd, 0.1, 0.0);
670+
}
671+
588672
//
589673
// Tests for TensorList helpers.
590674
//

0 commit comments

Comments
 (0)