Skip to content

Commit 6b1e328

Browse files
authored
[ExecuTorch] Support BFloat16 in CPUBlas gemm
Differential Revision: D62151658 Pull Request resolved: #5122
1 parent b69ae0c commit 6b1e328

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

kernels/optimized/blas/CPUBlas.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,28 @@ void gemm(
173173
}
174174
// clang-format on
175175

176+
// clang-format off
177+
void gemm(
178+
TransposeType transa, TransposeType transb,
179+
int64_t m, int64_t n, int64_t k,
180+
const BFloat16 alpha,
181+
const BFloat16 *a, int64_t lda,
182+
const BFloat16 *b, int64_t ldb,
183+
const BFloat16 beta,
184+
BFloat16 *c, int64_t ldc) {
185+
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
186+
187+
using acc_type = utils::compute_dtype<BFloat16>;
188+
gemm_impl(
189+
transa, transb,
190+
m, n, k,
191+
static_cast<const acc_type>(alpha),
192+
a, lda,
193+
b, ldb,
194+
static_cast<const acc_type>(beta),
195+
c, ldc);
196+
}
197+
// clang-format on
198+
176199
} // namespace cpublas
177200
} // namespace executorch

kernels/optimized/blas/CPUBlas.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
namespace executorch {
1818
namespace cpublas {
1919

20+
using BFloat16 = torch::executor::BFloat16;
2021
using Half = torch::executor::Half;
2122

2223
enum class TransposeType {
@@ -104,6 +105,15 @@ void gemm(
104105
const Half *b, int64_t ldb,
105106
const Half beta,
106107
Half *c, int64_t ldc);
108+
109+
void gemm(
110+
TransposeType transa, TransposeType transb,
111+
int64_t m, int64_t n, int64_t k,
112+
const BFloat16 alpha,
113+
const BFloat16 *a, int64_t lda,
114+
const BFloat16 *b, int64_t ldb,
115+
const BFloat16 beta,
116+
BFloat16 *c, int64_t ldc);
107117
// clang-format on
108118

109119
// clang-format off

kernels/optimized/test/libblas_test.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <gtest/gtest.h>
1010

1111
#include <executorch/kernels/optimized/blas/CPUBlas.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213

1314
#include <vector>
1415

@@ -17,7 +18,8 @@
1718
_<float, N>(); \
1819
_<int64_t, N>(); \
1920
_<uint8_t, N>(); \
20-
_<int32_t, N>();
21+
_<int32_t, N>(); \
22+
_<exec_aten::BFloat16, N>();
2123

2224
namespace {
2325

0 commit comments

Comments
 (0)