Skip to content

Commit 03f2426

Browse files
Michael Gschwindfacebook-github-bot
authored andcommitted
Enable tensor closeness check for additional types
Summary: Enable tensor closeness check for additional types Differential Revision: D54538932
1 parent 0977924 commit 03f2426

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ namespace {
3232
* T must be a floating point type. Non-floating point data should be compared
3333
* directly.
3434
*/
35-
template <
36-
typename T,
37-
typename = std::enable_if_t<std::is_floating_point<T>::value>>
35+
template <typename T>
3836
bool data_is_close(
3937
const T* a,
4038
const T* b,
@@ -110,6 +108,13 @@ bool tensors_are_close(
110108
a.numel(),
111109
rtol,
112110
atol);
111+
} else if (a.scalar_type() == ScalarType::Half) {
112+
return data_is_close<Half>(
113+
a.const_data_ptr<Half>(),
114+
b.const_data_ptr<Half>(),
115+
a.numel(),
116+
rtol,
117+
atol);
113118
} else {
114119
// Non-floating-point types can be compared bitwise.
115120
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
@@ -150,6 +155,13 @@ bool tensor_data_is_close(
150155
a.numel(),
151156
rtol,
152157
atol);
158+
} else if (a.scalar_type() == ScalarType::Half) {
159+
return data_is_close<Half>(
160+
a.const_data_ptr<Half>(),
161+
b.const_data_ptr<Half>(),
162+
a.numel(),
163+
rtol,
164+
atol);
153165
} else {
154166
// Non-floating-point types can be compared bitwise.
155167
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;

0 commit comments

Comments
 (0)