Commit 76921f8

[Executorch][llama] Set # of threads to use performant cores
Pull Request resolved: #2352

When all cores are used, the slower ones drag performance down by blocking the large cores. Once we have uarch-specific implementations we may not need this, but the utility is useful in general until we have a better API.

ghstack-source-id: 218361721
@exported-using-ghexport

Differential Revision: [D54766071](https://our.internmc.facebook.com/intern/diff/D54766071/)
1 parent 2103c36 commit 76921f8
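
To make the motivation in the commit message concrete, here is a deliberately simplified model, not taken from the ExecuTorch sources. It assumes the work is split into equal static chunks, one per thread, so the total time is set by the slowest core. The function name `static_partition_time` and the relative core speeds are hypothetical and chosen only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical model: `total_work` units are split evenly across one thread
// per core, and each core processes `speeds[i]` units per unit of time.
// The parallel region finishes only when the slowest core finishes.
double static_partition_time(
    double total_work, const std::vector<double>& speeds) {
  assert(!speeds.empty());
  const double chunk = total_work / static_cast<double>(speeds.size());
  double finish_time = 0.0;
  for (double speed : speeds) {
    assert(speed > 0.0);
    finish_time = std::max(finish_time, chunk / speed);
  }
  return finish_time;
}
```

Under these assumptions, 8 units of work on four big cores (speed 1.0) plus four little cores (speed 0.4) take about 2.5 time units, because each little core needs 1.0 / 0.4 to finish its chunk. The same work on the four big cores alone takes 2.0. Real schedulers and work-stealing thread pools behave differently, so treat this as intuition only.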

2 files changed: +11 -0 lines changed

examples/models/llama2/main.cpp

Lines changed: 9 additions & 0 deletions
@@ -10,6 +10,9 @@

 #include <executorch/examples/models/llama2/runner/runner.h>

+#include <executorch/backends/xnnpack/threadpool/cpuinfo_utils.h>
+#include <executorch/backends/xnnpack/threadpool/threadpool.h>
+
 DEFINE_string(
     model_path,
     "llama2.pte",
@@ -45,6 +48,12 @@ int32_t main(int32_t argc, char** argv) {

   int32_t seq_len = FLAGS_seq_len;

+  uint32_t num_performant_cores =
+      torch::executorch::cpuinfo::get_num_performant_cores();
+  ET_LOG(
+      Info, "Resetting threadpool with num threads = %d", num_performant_cores);
+  torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
+      num_performant_cores);
   // create llama runner
   ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
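
The change above calls `get_num_performant_cores()`, but the diff does not show how that count is derived. As a rough illustration of one possible heuristic, and not the ExecuTorch implementation, the sketch below counts cores whose maximum frequency is close to that of the fastest core. The helper name `count_performant_cores`, the `ratio` threshold, and the idea of feeding it per-core frequencies (for example read from `/sys/devices/system/cpu/cpuN/cpufreq/cpuinfo_max_freq` on Linux) are all assumptions made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of ExecuTorch): counts the cores whose
// maximum frequency is at least `ratio` times the fastest core's maximum.
// `max_freq_khz` holds one entry per core.
uint32_t count_performant_cores(
    const std::vector<uint64_t>& max_freq_khz, double ratio = 0.75) {
  assert(ratio > 0.0 && ratio <= 1.0);
  if (max_freq_khz.empty()) {
    return 1; // Unknown topology: fall back to a single thread.
  }
  const uint64_t fastest =
      *std::max_element(max_freq_khz.begin(), max_freq_khz.end());
  const double threshold = ratio * static_cast<double>(fastest);
  uint32_t count = 0;
  for (uint64_t freq : max_freq_khz) {
    if (static_cast<double>(freq) >= threshold) {
      ++count;
    }
  }
  return count;
}
```

For a hypothetical 4+3+1 layout with maximum frequencies of 1.8 GHz, 2.4 GHz and 2.84 GHz, the default threshold keeps the three mid cores and the prime core and drops the four little cores. A frequency cutoff is only a proxy. Classifying cores by microarchitecture is another option, which the commit message hints at with its mention of uarch-specific implementations.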

examples/models/llama2/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ def define_common_targets():
             deps = [
                 "//executorch/examples/models/llama2/runner:runner" + aten_suffix,
                 "//executorch/extension/evalue_util:print_evalue",
+                "//executorch/backends/xnnpack/threadpool:threadpool",
+                "//executorch/backends/xnnpack/threadpool:cpuinfo_utils",
             ],
             external_deps = [
                 "gflags",
