
Commit 7e374d7

Add model execution scripts and runner (#5217)
Summary: Add execution scripts and runner for 8 OSS models.

Pull Request resolved: #5217
Reviewed By: kirklandsign
Differential Revision: D62479707
Pulled By: cccclai
fbshipit-source-id: 81310dbb6b785ec59329110ebacb8208102e8597
1 parent d7a7ec6 commit 7e374d7

File tree: 16 files changed, +1704 −3 lines


backends/mediatek/CMakeLists.txt

Lines changed: 7 additions & 3 deletions

@@ -25,9 +25,13 @@ include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/include)

 # targets
 add_library(neuron_backend SHARED)
-target_link_libraries(
-  neuron_backend PRIVATE executorch_no_prim_ops portable_ops_lib android log
-  ${NEURON_BUFFER_ALLOCATOR_LIB}
+target_link_libraries(neuron_backend
+  PRIVATE
+  executorch_no_prim_ops
+  portable_ops_lib
+  android
+  log
+  ${NEURON_BUFFER_ALLOCATOR_LIB}
 )
 target_sources(
   neuron_backend

examples/mediatek/CMakeLists.txt

Lines changed: 38 additions & 0 deletions

@@ -75,6 +75,44 @@ if(${ANDROID})
   )
   target_compile_options(mtk_executor_runner PUBLIC ${_common_compile_options})

+  set(_mtk_oss_executor_runner__srcs ${_executor_runner__srcs})
+  list(
+    TRANSFORM
+    _mtk_oss_executor_runner__srcs
+    PREPEND
+    "${EXECUTORCH_SOURCE_DIR}/"
+  )
+  list(
+    FILTER
+    _mtk_oss_executor_runner__srcs
+    EXCLUDE REGEX
+    ".*executor_runner.cpp$"
+  )
+  list(
+    PREPEND
+    _mtk_oss_executor_runner__srcs
+    ${CMAKE_CURRENT_LIST_DIR}/executor_runner/mtk_oss_executor_runner.cpp
+  )
+
+  add_executable(mtk_oss_executor_runner ${_mtk_oss_executor_runner__srcs})
+
+  target_include_directories(mtk_oss_executor_runner
+    PUBLIC
+    ${_common_include_directories}
+    ${EXECUTORCH_ROOT}/cmake-android-out/third-party/gflags/include
+  )
+
+  target_link_libraries(mtk_oss_executor_runner
+    ${_executor_runner_libs}
+    executorch
+    neuron_backend
+    gflags
+  )
+  target_compile_options(mtk_oss_executor_runner
+    PUBLIC
+    ${_common_compile_options}
+  )
+
   set(_mtk_llama_executor_runner__srcs ${_mtk_executor_runner__srcs})
   list(FILTER _mtk_llama_executor_runner__srcs EXCLUDE REGEX
        ".*executor_runner.cpp$"

examples/mediatek/README.md

Lines changed: 36 additions & 0 deletions

@@ -9,6 +9,8 @@ examples/mediatek
 ├── preformatter_templates      # Model specific prompt preformatter templates
 ├── prompts                     # Calibration Prompts
 ├── tokenizers_                 # Model tokenizer scripts
+├── oss_utils                   # Utils for oss models
+├── eval_utils                  # Utils for evaluating oss model outputs
 ├── model_export_scripts        # Model specific export scripts
 ├── models                      # Model definitions
 ├── llm_models                  # LLM model definitions

@@ -44,6 +46,7 @@ pip3 install mtk_converter-8.8.0.dev20240723+public.d1467db9-cp310-cp310-manylin
 ```

 ## AoT Flow
+### llama
 ##### Note: Verify that localhost connection is available before running AoT Flow
 1. Exporting Models to `.pte`
 - In the `examples/mediatek` directory, run:

@@ -72,6 +75,14 @@ source shell_scripts/export_llama.sh <model_name> <num_chunks> <prompt_num_token
 - e.g. For `llama3-8B-instruct`, embedding bin generated in `examples/mediatek/models/llm_models/weights/llama3-8B-instruct/`
 - AoT flow will take roughly 2.5 hours (114GB RAM for `num_chunks=4`) to complete (results will vary by device/hardware configuration)

+### oss
+1. Exporting Model to `.pte`
+```bash
+bash shell_scripts/export_oss.sh <model_name>
+```
+- Argument Options:
+    - `model_name`: deeplabv3/edsr/inceptionv3/inceptionv4/mobilenetv2/mobilenetv3/resnet18/resnet50
+
 # Runtime
 ## Supported Chips

@@ -100,6 +111,13 @@ adb push <MODEL_NAME>.pte <PHONE_PATH, e.g. /data/local/tmp>

 Make sure to replace `<MODEL_NAME>` with the actual name of your model file, and replace `<PHONE_PATH>` with the desired destination on the device.

+##### Note: For oss models, please push additional files to your Android device:
+```bash
+adb push mtk_oss_executor_runner <PHONE_PATH, e.g. /data/local/tmp>
+adb push input_list.txt <PHONE_PATH, e.g. /data/local/tmp>
+for i in input*bin; do adb push "$i" <PHONE_PATH, e.g. /data/local/tmp>; done;
+```
+
 ### Executing the Model

 Execute the model on your Android device by running:

@@ -111,3 +129,21 @@ adb shell "/data/local/tmp/mtk_executor_runner --model_path /data/local/tmp/<MOD
 In the command above, replace `<MODEL_NAME>` with the name of your model file and `<ITER_TIMES>` with the desired number of iterations to run the model.

 ##### Note: For llama models, please use `mtk_llama_executor_runner`. Refer to `examples/mediatek/executor_runner/run_llama3_sample.sh` for reference.
+##### Note: For oss models, please use `mtk_oss_executor_runner`.
+```bash
+adb shell "/data/local/tmp/mtk_oss_executor_runner --model_path /data/local/tmp/<MODEL_NAME>.pte --input_list /data/local/tmp/input_list.txt --output_folder /data/local/tmp/output_<MODEL_NAME>"
+adb pull "/data/local/tmp/output_<MODEL_NAME>" ./
+```
+
+### Check oss result on PC
+```bash
+python3 eval_utils/eval_oss_result.py --eval_type <eval_type> --target_f <golden_folder> --output_f <prediction_folder>
+```
+For example:
+```bash
+python3 eval_utils/eval_oss_result.py --eval_type piq --target_f edsr --output_f output_edsr
+```
+- Argument Options:
+    - `eval_type`: topk/piq/segmentation
+    - `target_f`: folder containing the golden data files, named `golden_<data_idx>_0.bin`
+    - `output_f`: folder containing the model output files, named `output_<data_idx>_0.bin`
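
To make the `topk` evaluation concrete, below is a hedged sketch of what such a check could look like over the raw `.bin` files. It is not the repository's `eval_utils/eval_oss_result.py`; the flat float32 logits layout of the golden and output files, and the placeholder folder names, are assumptions.

```python
# Hedged illustration (not the repo's eval script): check whether the golden top-1 class
# appears in the predicted top-k for one golden/output pair of raw bins.
# Assumption: each .bin stores a flat float32 logits vector for a single sample.
import numpy as np


def top1_in_topk(golden_path: str, output_path: str, k: int = 5) -> bool:
    golden = np.fromfile(golden_path, dtype=np.float32)
    output = np.fromfile(output_path, dtype=np.float32)
    golden_top1 = int(np.argmax(golden))       # reference class index
    output_topk = np.argsort(output)[-k:]      # k highest-scoring predicted classes
    return golden_top1 in output_topk


# Placeholder paths following the naming convention described above.
print(top1_in_topk("golden_folder/golden_0_0.bin", "output_folder/output_0_0.bin"))
```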
Lines changed: 73 additions & 0 deletions (new file)

# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch
from executorch import exir
from executorch.backends.mediatek import (
    NeuropilotPartitioner,
    NeuropilotQuantizer,
    Precision,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def build_executorch_binary(
    model,
    inputs,
    file_name,
    dataset,
    quant_dtype: Optional[Precision] = None,
):
    if quant_dtype is not None:
        quantizer = NeuropilotQuantizer()
        quantizer.setup_precision(quant_dtype)
        if quant_dtype not in Precision:
            raise AssertionError(f"No support for Precision {quant_dtype}.")

        captured_model = torch._export.capture_pre_autograd_graph(model, inputs)
        annotated_model = prepare_pt2e(captured_model, quantizer)
        print("Quantizing the model...")
        # calibration
        for data in dataset:
            annotated_model(*data)
        quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
        aten_dialect = torch.export.export(quantized_model, inputs)
    else:
        aten_dialect = torch.export.export(model, inputs)

    from executorch.exir.program._program import to_edge_transform_and_lower

    edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False)
    # skipped op names are used for deeplabV3 model
    neuro_partitioner = NeuropilotPartitioner(
        [],
        op_names_to_skip={
            "aten_convolution_default_106",
            "aten_convolution_default_107",
        },
    )
    edge_prog = to_edge_transform_and_lower(
        aten_dialect,
        compile_config=edge_compile_config,
        partitioner=[neuro_partitioner],
    )

    exec_prog = edge_prog.to_executorch(
        config=exir.ExecutorchBackendConfig(extract_constant_segment=False)
    )
    with open(f"{file_name}.pte", "wb") as file:
        file.write(exec_prog.buffer)


def make_output_dir(path: str):
    if os.path.exists(path):
        for f in os.listdir(path):
            os.remove(os.path.join(path, f))
        os.removedirs(path)
    os.makedirs(path)
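
The helper above drives the oss export flow: optional PT2E quantization with the Neuropilot quantizer, lowering through `to_edge_transform_and_lower` with the Neuropilot partitioner, and serialization to a `.pte` file. For orientation, here is a hedged usage sketch; the torchvision MobileNetV2 model, the random calibration data, and the `Precision.A8W8` value are illustrative assumptions, not part of this commit.

```python
# Hedged usage sketch (not part of the commit): export MobileNetV2 with the helper above.
# Assumptions: torchvision is installed, build_executorch_binary is in scope, and
# Precision.A8W8 is a valid quantization setting for the Neuropilot quantizer.
import torch
from torchvision.models import mobilenet_v2

from executorch.backends.mediatek import Precision

model = mobilenet_v2(weights=None).eval()   # random weights suffice for a smoke-test export
sample_input = (torch.randn(1, 3, 224, 224),)

# Tiny random calibration set; a real flow would use representative images.
calibration_set = [(torch.randn(1, 3, 224, 224),) for _ in range(4)]

build_executorch_binary(
    model,
    sample_input,
    "mobilenetv2",                          # writes mobilenetv2.pte
    calibration_set,
    quant_dtype=Precision.A8W8,
)
```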
