Skip to content

Commit dbb631c

Browse files
committed
[ExecuTorch] Use parallel_for in bfloat16 gemm_transa_ kernel
Pull Request resolved: #5248 The upstream kernel uses this, I just didn't port it at first. ghstack-source-id: 242278002 @exported-using-ghexport Differential Revision: [D62154262](https://our.internmc.facebook.com/intern/diff/D62154262/)
1 parent 1793c4a commit dbb631c

File tree

3 files changed

+37
-20
lines changed

3 files changed

+37
-20
lines changed

build/cmake_deps.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ excludes = [
114114
deps = [
115115
"executorch_no_prim_ops",
116116
"executorch",
117+
"extension_parallel_thread_parallel",
117118
]
118119

119120
[targets.optimized_native_cpu_ops_oss]
@@ -197,6 +198,17 @@ deps = [
197198
"executorch",
198199
"executorch_no_prim_ops",
199200
]
201+
202+
[targets.extension_parallel_thread_parallel]
203+
buck_targets = [
204+
"//extension/parallel:thread_parallel",
205+
]
206+
filters = [
207+
".cpp$",
208+
]
209+
deps = [
210+
"executorch",
211+
]
200212
# ---------------------------------- extension end ----------------------------------
201213
# ---------------------------------- binary start ----------------------------------
202214

kernels/optimized/blas/BlasKernel.h

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/kernels/optimized/utils/math_utils.h>
1212
#include <executorch/kernels/optimized/utils/unroll.h>
1313

14+
#include <executorch/extension/parallel/thread_parallel.h>
1415
#include <executorch/runtime/core/portable_type/bfloat16.h>
1516

1617
#include <array>
@@ -177,34 +178,37 @@ inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
177178
torch::executor::BFloat16 beta,
178179
torch::executor::BFloat16 *c, int64_t ldc) {
179180
// c = alpha * (a.T @ b) + beta * c
180-
// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
181181
if (alpha == 1 && beta == 0) {
182-
const auto *a_ = a;
183-
for (int i = 0; i < m; ++i) {
182+
executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
183+
const auto *a_ = a + begin * lda;
184+
for (int i = begin; i < end; ++i) {
185+
const auto *b_ = b;
186+
for (int j = 0; j < n; ++j) {
187+
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
188+
b_ += ldb;
189+
c[j*ldc+i] = dot;
190+
}
191+
a_ += lda;
192+
}
193+
});
194+
return;
195+
}
196+
executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
197+
const auto *a_ = a + begin * lda;
198+
for (int i = begin; i < end; ++i) {
184199
const auto *b_ = b;
185200
for (int j = 0; j < n; ++j) {
186201
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
187202
b_ += ldb;
188-
c[j*ldc+i] = dot;
203+
if (beta == 0) {
204+
c[j*ldc+i] = alpha*dot;
205+
} else {
206+
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
207+
}
189208
}
190209
a_ += lda;
191210
}
192-
return;
193-
}
194-
const auto *a_ = a;
195-
for (int i = 0; i < m; ++i) {
196-
const auto *b_ = b;
197-
for (int j = 0; j < n; ++j) {
198-
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
199-
b_ += ldb;
200-
if (beta == 0) {
201-
c[j*ldc+i] = alpha*dot;
202-
} else {
203-
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
204-
}
205-
}
206-
a_ += lda;
207-
}
211+
});
208212
}
209213
#endif
210214

kernels/optimized/lib_defs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def define_libs():
157157
"DEFAULT": [],
158158
}),
159159
exported_deps = [
160+
"//executorch/extension/parallel:thread_parallel",
160161
"//executorch/kernels/optimized:libutils",
161162
"//executorch/runtime/core/exec_aten:lib",
162163
],

0 commit comments

Comments
 (0)