Skip to content

Commit 8b5d1ef

Browse files
swolchok and YIWENX14
authored and committed
Support Half/BFloat16 in var (#7913)
Partial fix for #7748.
1 parent f104e0f commit 8b5d1ef

File tree

2 files changed

+49
-26
lines changed

2 files changed

+49
-26
lines changed

kernels/portable/cpu/op_var.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ void compute_variance(
3838
in,
3939
dim_list,
4040
out_ix);
41-
CTYPE_OUT mean = sum / num;
41+
CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
4242
CTYPE_OUT sum2 = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
4343
[mean](CTYPE_IN v) {
4444
return (
@@ -90,8 +90,8 @@ Tensor& var_out(
9090

9191
constexpr auto name = "var.out";
9292

93-
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
94-
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
93+
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
94+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
9595
compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
9696
});
9797
});
@@ -135,8 +135,8 @@ Tensor& var_correction_out(
135135
const size_t num = get_reduced_dim_product(in, dim_list);
136136
const double denom = num - correction_val;
137137

138-
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
139-
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
138+
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
139+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
140140
compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
141141
});
142142
});

kernels/test/op_var_test.cpp

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@ using exec_aten::ScalarType;
2525
using exec_aten::Tensor;
2626
using torch::executor::testing::TensorFactory;
2727

28+
namespace {
29+
void expect_tensor_close_with_increased_tol(
30+
const Tensor& actual,
31+
const Tensor& expected) {
32+
if (actual.scalar_type() == ScalarType::BFloat16 ||
33+
actual.scalar_type() == ScalarType::Half) {
34+
EXPECT_TENSOR_CLOSE_WITH_TOL(expected, actual, 1e-2, 1e-2);
35+
} else {
36+
EXPECT_TENSOR_CLOSE(expected, actual);
37+
}
38+
}
39+
} // namespace
40+
2841
class OpVarOutTest : public OperatorTest {
2942
protected:
3043
Tensor& op_var_out(
@@ -142,7 +155,7 @@ class OpVarOutTest : public OperatorTest {
142155
op_var_out(
143156
self, optional_dim_list, /*unbiased=*/true, /*keepdim=*/true, out);
144157
// clang-format off
145-
EXPECT_TENSOR_CLOSE(out, tf_out.make(
158+
expect_tensor_close_with_increased_tol(out, tf_out.make(
146159
{2, 3, 1},
147160
{
148161
1.666667,
@@ -160,7 +173,7 @@ class OpVarOutTest : public OperatorTest {
160173
op_var_out(
161174
self, optional_dim_list, /*unbiased=*/true, /*keepdim=*/false, out);
162175
// clang-format off
163-
EXPECT_TENSOR_CLOSE(out, tf_out.make(
176+
expect_tensor_close_with_increased_tol(out, tf_out.make(
164177
{2, 3},
165178
{
166179
1.666667, 1.666667, 1.666667,
@@ -174,12 +187,14 @@ class OpVarOutTest : public OperatorTest {
174187
optional_dim_list = ArrayRef<int64_t>{dims_2, 2};
175188
op_var_out(
176189
self, optional_dim_list, /*unbiased=*/true, /*keepdim=*/true, out);
177-
EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 1, 4}, {56.0, 56.0, 56.0, 56.0}));
190+
expect_tensor_close_with_increased_tol(
191+
out, tf_out.make({1, 1, 4}, {56.0, 56.0, 56.0, 56.0}));
178192

179193
out = tf_out.zeros({4});
180194
op_var_out(
181195
self, optional_dim_list, /*unbiased=*/true, /*keepdim=*/false, out);
182-
EXPECT_TENSOR_CLOSE(out, tf_out.make({4}, {56.0, 56.0, 56.0, 56.0}));
196+
expect_tensor_close_with_increased_tol(
197+
out, tf_out.make({4}, {56.0, 56.0, 56.0, 56.0}));
183198

184199
// dim list with negative dimensions should work
185200
out = tf_out.zeros({2, 1, 4});
@@ -188,7 +203,7 @@ class OpVarOutTest : public OperatorTest {
188203
op_var_out(
189204
self, optional_dim_list, /*unbiased=*/false, /*keepdim=*/true, out);
190205
// clang-format off
191-
EXPECT_TENSOR_CLOSE(out, tf_out.make(
206+
expect_tensor_close_with_increased_tol(out, tf_out.make(
192207
{2, 1, 4},
193208
{
194209
10.666667, 10.666667, 10.666667, 10.666667,
@@ -201,18 +216,19 @@ class OpVarOutTest : public OperatorTest {
201216
out = tf_out.zeros({1, 1, 1});
202217
optional<ArrayRef<int64_t>> null_dim_list;
203218
op_var_out(self, null_dim_list, /*unbiased=*/true, /*keepdim=*/true, out);
204-
EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 1, 1}, {50.0}));
219+
expect_tensor_close_with_increased_tol(out, tf_out.make({1, 1, 1}, {50.0}));
205220

206221
optional<ArrayRef<int64_t>> empty_dim_list{ArrayRef<int64_t>{}};
207222
op_var_out(self, empty_dim_list, /*unbiased=*/false, /*keepdim=*/true, out);
208-
EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 1, 1}, {47.916668}));
223+
expect_tensor_close_with_increased_tol(
224+
out, tf_out.make({1, 1, 1}, {47.916668}));
209225

210226
out = tf_out.zeros({});
211227
op_var_out(self, null_dim_list, /*unbiased=*/false, /*keepdim=*/false, out);
212-
EXPECT_TENSOR_CLOSE(out, tf_out.make({}, {47.916668}));
228+
expect_tensor_close_with_increased_tol(out, tf_out.make({}, {47.916668}));
213229

214230
op_var_out(self, empty_dim_list, /*unbiased=*/true, /*keepdim=*/false, out);
215-
EXPECT_TENSOR_CLOSE(out, tf_out.make({}, {50.0}));
231+
expect_tensor_close_with_increased_tol(out, tf_out.make({}, {50.0}));
216232
}
217233
};
218234

@@ -227,6 +243,20 @@ class OpVarCorrectionOutTest : public OperatorTest {
227243
return torch::executor::aten::var_outf(
228244
context_, self, dim, correction, keepdim, out);
229245
}
246+
247+
template <ScalarType DTYPE>
248+
void test_dtype() {
249+
TensorFactory<DTYPE> tf;
250+
251+
Tensor x = tf.make({2, 3}, {4.9, 4.0, 5.6, 3.8, 4.9, 5.6});
252+
Tensor expected = tf.make({2}, {0.72693, 0.93032});
253+
optional<Scalar> correction(1.23);
254+
Tensor out = tf.zeros({2});
255+
256+
op_var_correction_out(
257+
x, ArrayRef<int64_t>{1}, correction, /*keepdim=*/false, out);
258+
expect_tensor_close_with_increased_tol(out, expected);
259+
}
230260
};
231261

232262
TEST_F(OpVarOutTest, InvalidDimensionListDies) {
@@ -303,9 +333,9 @@ TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) {
303333
test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
304334

305335
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
306-
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
336+
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
307337

308-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
338+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
309339
#undef TEST_ENTRY
310340
#undef TEST_KERNEL
311341
}
@@ -387,14 +417,7 @@ TEST_F(OpVarOutTest, DynamicShapeUnbound) {
387417
}
388418

389419
TEST_F(OpVarCorrectionOutTest, SmokeTest) {
390-
TensorFactory<ScalarType::Float> tf;
391-
392-
Tensor x = tf.make({2, 3}, {4.9, 4.0, 5.6, 3.8, 4.9, 5.6});
393-
Tensor expected = tf.make({2}, {0.72693, 0.93032});
394-
optional<Scalar> correction(1.23);
395-
Tensor out = tf.zeros({2});
396-
397-
op_var_correction_out(
398-
x, ArrayRef<int64_t>{1}, correction, /*keepdim=*/false, out);
399-
EXPECT_TENSOR_CLOSE(out, expected);
420+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
421+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
422+
#undef TEST_ENTRY
400423
}

0 commit comments

Comments (0)