Commit c02546c

[ExecuTorch] support BF16 in op_mm
Differential Revision: D61981353
Pull Request resolved: #4978
1 parent e33c25c · commit c02546c

2 files changed: +15 additions, -14 deletions


kernels/portable/cpu/op_mm.cpp

Lines changed: 14 additions & 13 deletions
@@ -34,19 +34,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) {
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
-    size_t m = in.size(0);
-    size_t n = in.size(1);
-    size_t p = mat2.size(1);
-
-    vec_matmul<CTYPE>(
-        out.mutable_data_ptr<CTYPE>(),
-        in.const_data_ptr<CTYPE>(),
-        mat2.const_data_ptr<CTYPE>(),
-        m,
-        n,
-        p);
-  });
+  ET_SWITCH_REAL_TYPES_AND2(
+      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
+        size_t m = in.size(0);
+        size_t n = in.size(1);
+        size_t p = mat2.size(1);
+
+        vec_matmul<CTYPE>(
+            out.mutable_data_ptr<CTYPE>(),
+            in.const_data_ptr<CTYPE>(),
+            mat2.const_data_ptr<CTYPE>(),
+            m,
+            n,
+            p);
+      });
 
   return out;
 }
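
For readers less familiar with the ExecuTorch dtype-dispatch macros: ET_SWITCH_REAL_TYPES_AND2(Half, BFloat16, ...) picks the concrete C++ type that matches the runtime scalar type and runs the kernel body once with that type bound to CTYPE, so adding BFloat16 to the macro is what lets vec_matmul be instantiated and dispatched for bf16 inputs. The sketch below is only an illustrative stand-in for that idea, not the real macro expansion; ScalarType, the stand-in 16-bit types, and dispatch_real_h_bf16 are hypothetical names, not ExecuTorch APIs.

// Illustrative sketch: a hand-written analogue of a
// "switch over real types plus Half and BFloat16" dispatcher.
#include <cstdint>
#include <cstdio>

enum class ScalarType { Float, Double, Int, Long, Half, BFloat16 };

struct HalfStandIn { uint16_t bits; };  // placeholder for a real fp16 type
struct BF16StandIn { uint16_t bits; };  // placeholder for a real bf16 type

template <typename Fn>
bool dispatch_real_h_bf16(ScalarType t, Fn&& fn) {
  switch (t) {
    case ScalarType::Float:    fn(float{});       return true;
    case ScalarType::Double:   fn(double{});      return true;
    case ScalarType::Int:      fn(int32_t{});     return true;
    case ScalarType::Long:     fn(int64_t{});     return true;
    case ScalarType::Half:     fn(HalfStandIn{}); return true;
    case ScalarType::BFloat16: fn(BF16StandIn{}); return true;  // the newly covered case
  }
  return false;  // unsupported dtype
}

int main() {
  // The generic lambda plays the role of the CTYPE-parameterized kernel body.
  dispatch_real_h_bf16(ScalarType::BFloat16, [](auto tag) {
    using CTYPE = decltype(tag);
    std::printf("running mm kernel with sizeof(CTYPE) = %zu\n", sizeof(CTYPE));
  });
}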

kernels/test/op_mm_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) {
 /// zeros().
 TEST_F(OpMmOutTest, AllDtypesSupported) {
 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
-  ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY);
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
   // TODO: Also add tests for half, complex, quantized, and other types. Easiest
   // way to do that would be to make TensorFactory support zeros() and ones()
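
The test change reuses the existing X-macro pattern: TEST_ENTRY is instantiated once per (ctype, dtype) pair in the list, and switching from ET_FORALL_REAL_TYPES_AND(Half, ...) to ET_FORALL_REALHBF16_TYPES extends that list to include BFloat16. Below is a small self-contained illustration of the pattern; FORALL_SKETCH and its type list are made up for the example and are not the real ExecuTorch list.

// Self-contained illustration of the forall/X-macro test pattern.
#include <cstdint>
#include <cstdio>

enum class ScalarType { Float, Half, BFloat16 };

template <typename CTYPE, ScalarType DTYPE>
void test_dtype() {
  // The real test builds DTYPE tensors and checks mm.out against expected
  // values; here we only record that the dtype was exercised.
  std::printf("tested dtype %d, sizeof(ctype) = %zu\n",
              static_cast<int>(DTYPE), sizeof(CTYPE));
}

// Hypothetical stand-in for ET_FORALL_REALHBF16_TYPES: applies the given
// macro to each (ctype, dtype) pair.
#define FORALL_SKETCH(_)  \
  _(float, Float)         \
  _(uint16_t, Half)       \
  _(uint16_t, BFloat16) /* uint16_t stands in for the real 16-bit float types */

int main() {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
  FORALL_SKETCH(TEST_ENTRY)
#undef TEST_ENTRY
}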
