Skip to content

Commit aa237f8

Browse files
committed
[ET-SDK] Enable data_is_close for Half tensor
Currently, if we use ET-SDK's `VerifyResultWithBundledExpectedOutput()` on `Half` tensors it will return true only if all elements are exactly equal. This change allows the common behavior to check that all elements are within the specified `rtol`/`atol`. Differential Revision: [D58018861](https://our.internmc.facebook.com/intern/diff/D58018861/) [ghstack-poisoned]
1 parent 68c822f commit aa237f8

File tree

1 file changed

+50
-21
lines changed

1 file changed

+50
-21
lines changed

sdk/bundled_program/bundled_program.cpp

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,37 @@ TensorImpl impl_like(bundled_program_flatbuffer::Tensor* bundled_tensor) {
7373
#endif
7474

7575
/**
76-
* Returns true if the two arrays are close according to the description on
76+
* Returns true if the two elements are close according to the description on
7777
* `tensors_are_close()`.
7878
*
7979
* T must be a floating point type. Non-floating point data should be compared
8080
* directly.
8181
*/
82+
template <
83+
typename T,
84+
typename = std::enable_if_t<std::is_floating_point<T>::value>>
85+
bool elem_is_close(const T ai, const T bi, double rtol, double atol) {
86+
if (std::isnan(ai) && std::isnan(bi)) {
87+
// NaN == NaN
88+
} else if (
89+
!std::isfinite(ai) && !std::isfinite(bi) && ((ai > 0) == (bi > 0))) {
90+
// -Inf == -Inf
91+
// +Inf == +Inf
92+
} else if (rtol == 0 && atol == 0) {
93+
// Exact comparison; avoid unnecessary math.
94+
if (ai != bi) {
95+
return false;
96+
}
97+
} else {
98+
auto allowed_error = atol + std::abs(rtol * bi);
99+
auto actual_error = std::abs(ai - bi);
100+
if (!std::isfinite(actual_error) || actual_error > allowed_error) {
101+
return false;
102+
}
103+
}
104+
return true;
105+
}
106+
82107
template <
83108
typename T,
84109
typename = std::enable_if_t<std::is_floating_point<T>::value>>
@@ -89,26 +114,23 @@ bool data_is_close(
89114
double rtol,
90115
double atol) {
91116
for (size_t i = 0; i < numel; i++) {
92-
const auto ai = a[i];
93-
const auto bi = b[i];
94-
95-
if (std::isnan(ai) && std::isnan(bi)) {
96-
// NaN == NaN
97-
} else if (
98-
!std::isfinite(ai) && !std::isfinite(bi) && ((ai > 0) == (bi > 0))) {
99-
// -Inf == -Inf
100-
// +Inf == +Inf
101-
} else if (rtol == 0 && atol == 0) {
102-
// Exact comparison; avoid unnecessary math.
103-
if (ai != bi) {
104-
return false;
105-
}
106-
} else {
107-
auto allowed_error = atol + std::abs(rtol * bi);
108-
auto actual_error = std::abs(ai - bi);
109-
if (!std::isfinite(actual_error) || actual_error > allowed_error) {
110-
return false;
111-
}
117+
if (!elem_is_close(a[i], b[i], rtol, atol)) {
118+
return false;
119+
}
120+
}
121+
return true;
122+
}
123+
124+
bool data_is_close_half(
125+
const Half* a,
126+
const Half* b,
127+
size_t numel,
128+
double rtol,
129+
double atol) {
130+
for (size_t i = 0; i < numel; i++) {
131+
if (!elem_is_close(
132+
static_cast<double>(a[i]), static_cast<double>(b[i]), rtol, atol)) {
133+
return false;
112134
}
113135
}
114136
return true;
@@ -177,6 +199,13 @@ bool tensors_are_close(
177199
bundled_tensor.numel(),
178200
rtol,
179201
atol);
202+
} else if (bundled_tensor.scalar_type() == ScalarType::Half) {
203+
return data_is_close_half(
204+
bundled_tensor.const_data_ptr<Half>(),
205+
method_output_tensor.const_data_ptr<Half>(),
206+
bundled_tensor.numel(),
207+
rtol,
208+
atol);
180209
} else {
181210
// Non-floating-point types can be compared bitwise.
182211
return memcmp(

0 commit comments

Comments
 (0)