Skip to content

Commit 63b0a22

Browse files
Michael Gschwindfacebook-github-bot
authored andcommitted
Enable tensor closeness check for additional types (#2256)
Summary: Pull Request resolved: #2256 Enable tensor closeness check for additional types Reviewed By: manuelcandales Differential Revision: D54538932 fbshipit-source-id: 14d36f2bdcdee833a30995b664b9089e3264b511
1 parent aed32c4 commit 63b0a22

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ namespace {
3232
* T must be a floating point type. Non-floating point data should be compared
3333
* directly.
3434
*/
35-
template <
36-
typename T,
37-
typename = std::enable_if_t<std::is_floating_point<T>::value>>
35+
template <typename T>
3836
bool data_is_close(
3937
const T* a,
4038
const T* b,
@@ -110,6 +108,13 @@ bool tensors_are_close(
110108
a.numel(),
111109
rtol,
112110
atol);
111+
} else if (a.scalar_type() == ScalarType::Half) {
112+
return data_is_close<Half>(
113+
a.const_data_ptr<Half>(),
114+
b.const_data_ptr<Half>(),
115+
a.numel(),
116+
rtol,
117+
atol);
113118
} else {
114119
// Non-floating-point types can be compared bitwise.
115120
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
@@ -150,6 +155,13 @@ bool tensor_data_is_close(
150155
a.numel(),
151156
rtol,
152157
atol);
158+
} else if (a.scalar_type() == ScalarType::Half) {
159+
return data_is_close<Half>(
160+
a.const_data_ptr<Half>(),
161+
b.const_data_ptr<Half>(),
162+
a.numel(),
163+
rtol,
164+
atol);
153165
} else {
154166
// Non-floating-point types can be compared bitwise.
155167
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;

runtime/core/portable_type/half.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,18 @@ std::ostream& operator<<(std::ostream& out, const Half& value);
681681

682682
namespace std {
683683

684+
static inline int isinf(torch::executor::Half value) {
685+
return (value.x & 0x7FFF) == 0x7C00;
686+
}
687+
688+
static inline int isnan(torch::executor::Half value) {
689+
return ((value.x & 0x7C00) == 0x7C00) && ((value.x & 0x03ff) != 0);
690+
}
691+
692+
static inline int isfinite(torch::executor::Half value) {
693+
return !(isinf(value) || isnan(value));
694+
}
695+
684696
template <>
685697
class numeric_limits<torch::executor::Half> {
686698
public:

0 commit comments

Comments
 (0)