Skip to content

Commit 6d1a573

Browse files
swolchok authored and facebook-github-bot committed
Use parallel_for in bfloat16 gemm_transa_ kernel (#5248)
Summary: Pull Request resolved: #5248. The upstream kernel uses parallel_for here; I just didn't port it at first.
ghstack-source-id: 242411944
exported-using-ghexport
Reviewed By: kimishpatel
Differential Revision: D62154262
fbshipit-source-id: f15023474d88974c56374dd9be6577fc37217d65
1 parent 1bb5b20 commit 6d1a573

File tree

5 files changed

+48
-22
lines changed

5 files changed

+48
-22
lines changed

build/cmake_deps.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ excludes = [
7373
deps = [
7474
"executorch",
7575
"executorch_no_prim_ops",
76+
"extension_threadpool",
7677
"portable_kernels",
7778
]
7879

@@ -197,6 +198,18 @@ deps = [
197198
"executorch",
198199
"executorch_no_prim_ops",
199200
]
201+
202+
[targets.extension_threadpool]
203+
buck_targets = [
204+
"//extension/threadpool:threadpool",
205+
]
206+
filters = [
207+
".cpp$",
208+
]
209+
deps = [
210+
"executorch",
211+
"executorch_no_prim_ops",
212+
]
200213
# ---------------------------------- extension end ----------------------------------
201214
# ---------------------------------- binary start ----------------------------------
202215

@@ -333,6 +346,7 @@ deps = [
333346
"executorch",
334347
"executorch_no_prim_ops",
335348
"optimized_kernels",
349+
"extension_threadpool",
336350
"xnnpack_backend",
337351
]
338352

kernels/optimized/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ endif()
4242
# Build cpublas.
4343
list(TRANSFORM _optimized_cpublas__srcs PREPEND "${EXECUTORCH_ROOT}/")
4444
add_library(cpublas STATIC ${_optimized_cpublas__srcs})
45-
target_link_libraries(cpublas PRIVATE executorch_no_prim_ops eigen_blas)
45+
target_link_libraries(
46+
cpublas PRIVATE executorch_no_prim_ops eigen_blas extension_threadpool
47+
)
4648
target_compile_options(cpublas PUBLIC ${_common_compile_options})
4749

4850
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
@@ -58,7 +60,9 @@ message("Generated files ${gen_command_sources}")
5860

5961
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
6062
add_library(optimized_kernels ${_optimized_kernels__srcs})
61-
target_link_libraries(optimized_kernels PRIVATE executorch_no_prim_ops cpublas)
63+
target_link_libraries(
64+
optimized_kernels PRIVATE executorch_no_prim_ops cpublas extension_threadpool
65+
)
6266
target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})
6367
# Build a library for _optimized_kernels_srcs
6468
#

kernels/optimized/blas/BlasKernel.h

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/kernels/optimized/utils/math_utils.h>
1212
#include <executorch/kernels/optimized/utils/unroll.h>
1313

14+
#include <executorch/extension/parallel/thread_parallel.h>
1415
#include <executorch/runtime/core/portable_type/bfloat16.h>
1516

1617
#include <array>
@@ -177,34 +178,37 @@ inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
177178
torch::executor::BFloat16 beta,
178179
torch::executor::BFloat16 *c, int64_t ldc) {
179180
// c = alpha * (a.T @ b) + beta * c
180-
// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
181181
if (alpha == 1 && beta == 0) {
182-
const auto *a_ = a;
183-
for (int i = 0; i < m; ++i) {
182+
executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
183+
const auto *a_ = a + begin * lda;
184+
for (int i = begin; i < end; ++i) {
185+
const auto *b_ = b;
186+
for (int j = 0; j < n; ++j) {
187+
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
188+
b_ += ldb;
189+
c[j*ldc+i] = dot;
190+
}
191+
a_ += lda;
192+
}
193+
});
194+
return;
195+
}
196+
executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
197+
const auto *a_ = a + begin * lda;
198+
for (int i = begin; i < end; ++i) {
184199
const auto *b_ = b;
185200
for (int j = 0; j < n; ++j) {
186201
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
187202
b_ += ldb;
188-
c[j*ldc+i] = dot;
203+
if (beta == 0) {
204+
c[j*ldc+i] = alpha*dot;
205+
} else {
206+
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
207+
}
189208
}
190209
a_ += lda;
191210
}
192-
return;
193-
}
194-
const auto *a_ = a;
195-
for (int i = 0; i < m; ++i) {
196-
const auto *b_ = b;
197-
for (int j = 0; j < n; ++j) {
198-
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
199-
b_ += ldb;
200-
if (beta == 0) {
201-
c[j*ldc+i] = alpha*dot;
202-
} else {
203-
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
204-
}
205-
}
206-
a_ += lda;
207-
}
211+
});
208212
}
209213
#endif
210214

kernels/optimized/lib_defs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def define_libs():
157157
"DEFAULT": [],
158158
}),
159159
exported_deps = [
160+
"//executorch/extension/parallel:thread_parallel",
160161
"//executorch/kernels/optimized:libutils",
161162
"//executorch/runtime/core/exec_aten:lib",
162163
],

kernels/test/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,12 @@ et_cxx_test(
253253
SOURCES
254254
${_optimized_kernels_test_sources}
255255
EXTRA_LIBS
256+
cpuinfo
257+
extension_threadpool
256258
optimized_kernels
257259
optimized_ops_lib
258260
portable_kernels
261+
pthreadpool
259262
eigen_blas
260263
)
261264
add_dependencies(optimized_kernels_test generate_wrapper)

0 commit comments

Comments
 (0)