File tree Expand file tree Collapse file tree 2 files changed +27
-3
lines changed Expand file tree Collapse file tree 2 files changed +27
-3
lines changed Original file line number Diff line number Diff line change @@ -32,9 +32,7 @@ namespace {
32
32
* T must be a floating point type. Non-floating point data should be compared
33
33
* directly.
34
34
*/
35
- template <
36
- typename T,
37
- typename = std::enable_if_t <std::is_floating_point<T>::value>>
35
+ template <typename T>
38
36
bool data_is_close (
39
37
const T* a,
40
38
const T* b,
@@ -110,6 +108,13 @@ bool tensors_are_close(
110
108
a.numel (),
111
109
rtol,
112
110
atol);
111
+ } else if (a.scalar_type () == ScalarType::Half) {
112
+ return data_is_close<Half>(
113
+ a.const_data_ptr <Half>(),
114
+ b.const_data_ptr <Half>(),
115
+ a.numel (),
116
+ rtol,
117
+ atol);
113
118
} else {
114
119
// Non-floating-point types can be compared bitwise.
115
120
return memcmp (a.const_data_ptr (), b.const_data_ptr (), a.nbytes ()) == 0 ;
@@ -150,6 +155,13 @@ bool tensor_data_is_close(
150
155
a.numel (),
151
156
rtol,
152
157
atol);
158
+ } else if (a.scalar_type () == ScalarType::Half) {
159
+ return data_is_close<Half>(
160
+ a.const_data_ptr <Half>(),
161
+ b.const_data_ptr <Half>(),
162
+ a.numel (),
163
+ rtol,
164
+ atol);
153
165
} else {
154
166
// Non-floating-point types can be compared bitwise.
155
167
return memcmp (a.const_data_ptr (), b.const_data_ptr (), a.nbytes ()) == 0 ;
Original file line number Diff line number Diff line change @@ -681,6 +681,18 @@ std::ostream& operator<<(std::ostream& out, const Half& value);
681
681
682
682
namespace std {
683
683
684
+ static inline int isinf (torch::executor::Half value) {
685
+ return (value.x & 0x7FFF ) == 0x7C00 ;
686
+ }
687
+
688
+ static inline int isnan (torch::executor::Half value) {
689
+ return ((value.x & 0x7C00 ) == 0x7C00 ) && ((value.x & 0x03ff ) != 0 );
690
+ }
691
+
692
+ static inline int isfinite (torch::executor::Half value) {
693
+ return !(isinf (value) || isnan (value));
694
+ }
695
+
684
696
template <>
685
697
class numeric_limits <torch::executor::Half> {
686
698
public:
You can’t perform that action at this time.
0 commit comments