Skip to content

Commit e826de3

Browse files
Add Half/BFloat16 tests for op_mul
Differential Revision: D62417216 Pull Request resolved: #5213
1 parent 549f14b commit e826de3

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

kernels/portable/cpu/op_mul.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ Tensor& mul_scalar_out(
123123
ET_KERNEL_CHECK(
124124
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
125125

126-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
126+
ET_KERNEL_CHECK(
127+
ctx,
128+
executorch::runtime::tensor_is_realhbbf16_type(out),
129+
InvalidArgument,
130+
out);
127131

128132
ScalarType a_type = a.scalar_type();
129133
ScalarType b_type = utils::get_scalar_dtype(b);

kernels/test/op_mul_test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,29 @@ TEST_F(OpMulScalarOutTest, OptimizedSanityCheck) {
586586
// Check that it matches the expected output.
587587
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4}));
588588
}
589+
590+
TEST_F(OpMulScalarOutTest, HalfSanityCheck) {
591+
TensorFactory<ScalarType::Half> tf;
592+
593+
const std::vector<int32_t> sizes = {2, 2};
594+
595+
Tensor out = tf.zeros(sizes);
596+
597+
op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out);
598+
599+
// Check that it matches the expected output.
600+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4}));
601+
}
602+
603+
TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) {
604+
TensorFactory<ScalarType::BFloat16> tf;
605+
606+
const std::vector<int32_t> sizes = {2, 2};
607+
608+
Tensor out = tf.zeros(sizes);
609+
610+
op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out);
611+
612+
// Check that it matches the expected output.
613+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4}));
614+
}

0 commit comments

Comments
 (0)