File tree Expand file tree Collapse file tree 1 file changed +15
-3
lines changed
runtime/core/exec_aten/testing_util Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -32,9 +32,7 @@ namespace {
32
32
* T must be a floating point type. Non-floating point data should be compared
33
33
* directly.
34
34
*/
35
- template <
36
- typename T,
37
- typename = std::enable_if_t <std::is_floating_point<T>::value>>
35
+ template <typename T>
38
36
bool data_is_close (
39
37
const T* a,
40
38
const T* b,
@@ -110,6 +108,13 @@ bool tensors_are_close(
110
108
a.numel (),
111
109
rtol,
112
110
atol);
111
+ } else if (a.scalar_type () == ScalarType::Half) {
112
+ return data_is_close<Half>(
113
+ a.const_data_ptr <Half>(),
114
+ b.const_data_ptr <Half>(),
115
+ a.numel (),
116
+ rtol,
117
+ atol);
113
118
} else {
114
119
// Non-floating-point types can be compared bitwise.
115
120
return memcmp (a.const_data_ptr (), b.const_data_ptr (), a.nbytes ()) == 0 ;
@@ -150,6 +155,13 @@ bool tensor_data_is_close(
150
155
a.numel (),
151
156
rtol,
152
157
atol);
158
+ } else if (a.scalar_type () == ScalarType::Half) {
159
+ return data_is_close<Half>(
160
+ a.const_data_ptr <Half>(),
161
+ b.const_data_ptr <Half>(),
162
+ a.numel (),
163
+ rtol,
164
+ atol);
153
165
} else {
154
166
// Non-floating-point types can be compared bitwise.
155
167
return memcmp (a.const_data_ptr (), b.const_data_ptr (), a.nbytes ()) == 0 ;
You can’t perform that action at this time.
0 commit comments