@@ -73,12 +73,37 @@ TensorImpl impl_like(bundled_program_flatbuffer::Tensor* bundled_tensor) {
73
73
#endif
74
74
75
75
/* *
76
- * Returns true if the two arrays are close according to the description on
76
+ * Returns true if the two elements are close according to the description on
77
77
* `tensors_are_close()`.
78
78
*
79
79
* T must be a floating point type. Non-floating point data should be compared
80
80
* directly.
81
81
*/
82
+ template <
83
+ typename T,
84
+ typename = std::enable_if_t <std::is_floating_point<T>::value>>
85
+ bool elem_is_close (const T ai, const T bi, double rtol, double atol) {
86
+ if (std::isnan (ai) && std::isnan (bi)) {
87
+ // NaN == NaN
88
+ } else if (
89
+ !std::isfinite (ai) && !std::isfinite (bi) && ((ai > 0 ) == (bi > 0 ))) {
90
+ // -Inf == -Inf
91
+ // +Inf == +Inf
92
+ } else if (rtol == 0 && atol == 0 ) {
93
+ // Exact comparison; avoid unnecessary math.
94
+ if (ai != bi) {
95
+ return false ;
96
+ }
97
+ } else {
98
+ auto allowed_error = atol + std::abs (rtol * bi);
99
+ auto actual_error = std::abs (ai - bi);
100
+ if (!std::isfinite (actual_error) || actual_error > allowed_error) {
101
+ return false ;
102
+ }
103
+ }
104
+ return true ;
105
+ }
106
+
82
107
template <
83
108
typename T,
84
109
typename = std::enable_if_t <std::is_floating_point<T>::value>>
@@ -89,26 +114,23 @@ bool data_is_close(
89
114
double rtol,
90
115
double atol) {
91
116
for (size_t i = 0 ; i < numel; i++) {
92
- const auto ai = a[i];
93
- const auto bi = b[i];
94
-
95
- if (std::isnan (ai) && std::isnan (bi)) {
96
- // NaN == NaN
97
- } else if (
98
- !std::isfinite (ai) && !std::isfinite (bi) && ((ai > 0 ) == (bi > 0 ))) {
99
- // -Inf == -Inf
100
- // +Inf == +Inf
101
- } else if (rtol == 0 && atol == 0 ) {
102
- // Exact comparison; avoid unnecessary math.
103
- if (ai != bi) {
104
- return false ;
105
- }
106
- } else {
107
- auto allowed_error = atol + std::abs (rtol * bi);
108
- auto actual_error = std::abs (ai - bi);
109
- if (!std::isfinite (actual_error) || actual_error > allowed_error) {
110
- return false ;
111
- }
117
+ if (!elem_is_close (a[i], b[i], rtol, atol)) {
118
+ return false ;
119
+ }
120
+ }
121
+ return true ;
122
+ }
123
+
124
+ bool data_is_close_half (
125
+ const Half* a,
126
+ const Half* b,
127
+ size_t numel,
128
+ double rtol,
129
+ double atol) {
130
+ for (size_t i = 0 ; i < numel; i++) {
131
+ if (!elem_is_close (
132
+ static_cast <double >(a[i]), static_cast <double >(b[i]), rtol, atol)) {
133
+ return false ;
112
134
}
113
135
}
114
136
return true ;
@@ -177,6 +199,13 @@ bool tensors_are_close(
177
199
bundled_tensor.numel (),
178
200
rtol,
179
201
atol);
202
+ } else if (bundled_tensor.scalar_type () == ScalarType::Half) {
203
+ return data_is_close_half (
204
+ bundled_tensor.const_data_ptr <Half>(),
205
+ method_output_tensor.const_data_ptr <Half>(),
206
+ bundled_tensor.numel (),
207
+ rtol,
208
+ atol);
180
209
} else {
181
210
// Non-floating-point types can be compared bitwise.
182
211
return memcmp (
0 commit comments