Skip to content

Commit ee2180e

Browse files
winskuo-quic authored and Zonglin Peng committed
Qualcomm AI Engine Direct - Meta CI for Mobilebert, W2L, and Llama (#8616)
* Qualcomm AI Engine Direct - Meta CI for Mobilebert and W2L * variable update
1 parent a048c2c commit ee2180e

File tree

9 files changed

+66
-25
lines changed

9 files changed

+66
-25
lines changed

.ci/scripts/test_model.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ test_model_with_qnn() {
164164
export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
165165
export PYTHONPATH=$EXECUTORCH_ROOT/..
166166

167+
EXTRA_FLAGS=""
167168
if [[ "${MODEL_NAME}" == "dl3" ]]; then
168169
EXPORT_SCRIPT=deeplab_v3
169170
elif [[ "${MODEL_NAME}" == "mv3" ]]; then
@@ -176,6 +177,12 @@ test_model_with_qnn() {
176177
EXPORT_SCRIPT=inception_v3
177178
elif [[ "${MODEL_NAME}" == "vit" ]]; then
178179
EXPORT_SCRIPT=torchvision_vit
180+
elif [[ "${MODEL_NAME}" == "mb" ]]; then
181+
EXPORT_SCRIPT=mobilebert_fine_tune
182+
EXTRA_FLAGS="--num_epochs 1"
183+
pip install scikit-learn
184+
elif [[ "${MODEL_NAME}" == "w2l" ]]; then
185+
EXPORT_SCRIPT=wav2letter
179186
elif [[ "${MODEL_NAME}" == "edsr" ]]; then
180187
EXPORT_SCRIPT=edsr
181188
# Additional deps for edsr
@@ -189,7 +196,7 @@ test_model_with_qnn() {
189196
# TODO(guangyang): Make the QNN chipset match the target device
190197
QNN_CHIPSET=SM8450
191198

192-
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only
199+
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only $EXTRA_FLAGS
193200
EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
194201
}
195202

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ jobs:
311311
strategy:
312312
matrix:
313313
dtype: [fp32]
314-
model: [dl3, mv3, mv2, ic4, ic3, vit]
314+
model: [dl3, mv3, mv2, ic4, ic3, vit, mb, w2l]
315315
fail-fast: false
316316
with:
317317
runner: linux.2xlarge

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
from executorch.examples.models.mobilenet_v3 import MV3Model
7474
from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
7575

76-
# from executorch.examples.models.wav2letter import Wav2LetterModel
76+
from executorch.examples.models.wav2letter import Wav2LetterModel
7777
from executorch.exir import to_edge
7878
from executorch.exir.backend.backend_api import disable_validation
7979
from executorch.exir.passes import PassManager
@@ -907,8 +907,7 @@ def test_qnn_backend_example_models(self):
907907
# Fails during lowering. Reopen once resolved.
908908
# MobileBertModelExample(),
909909
# TorchVisionViTModel(),
910-
# Encountered undefined symbol in mainline. Reopen once resolved.
911-
# Wav2LetterModel(),
910+
Wav2LetterModel(),
912911
]
913912
expected_partitions = [
914913
1,
@@ -917,8 +916,8 @@ def test_qnn_backend_example_models(self):
917916
1,
918917
1,
919918
1,
920-
1,
921-
1,
919+
# 1,
920+
# 1,
922921
1,
923922
]
924923
# TODO: This triggers a "maximum recursion depth exceeded" error; needs investigation.
@@ -1962,12 +1961,11 @@ def test_qnn_backend_example_models(self):
19621961
QCOM_ANNOTATION: (),
19631962
QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
19641963
},
1965-
# Encountered undefined symbol in mainline. Reopen once resolved.
1966-
# {
1967-
# QCOM_MODULE: Wav2LetterModel(),
1968-
# QCOM_ANNOTATION: (),
1969-
# QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
1970-
# },
1964+
{
1965+
QCOM_MODULE: Wav2LetterModel(),
1966+
QCOM_ANNOTATION: (),
1967+
QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
1968+
},
19711969
]
19721970
expected_partitions = [
19731971
1,
@@ -1979,7 +1977,7 @@ def test_qnn_backend_example_models(self):
19791977
# For MobileBertModelExample
19801978
# 1,
19811979
1,
1982-
# 1, For Wav2LetterModel
1980+
1,
19831981
]
19841982
# TODO: This triggers a "maximum recursion depth exceeded" error; needs investigation.
19851983
disable_validation()

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,7 @@ def post_process():
843843
)
844844

845845
runner_cmd = ""
846+
performance_output_path = "outputs/inference_speed.txt"
846847
if args.enable_x86_64:
847848
# x86 emulator is intended for CI and not performance. Check only the first few tokens.
848849
seq_len = min(seq_len, 16)
@@ -862,6 +863,7 @@ def post_process():
862863
f"--model_path {pte_path}",
863864
f"--seq_len {seq_len}",
864865
f"--output_path {args.artifact}/outputs/outputs.txt",
866+
f"--performance_output_path {performance_output_path}",
865867
f"--kv_updater ShiftPointer",
866868
runner_args,
867869
]
@@ -882,6 +884,7 @@ def post_process():
882884
f"--model_path {pte_filename}.pte",
883885
f"--seq_len {seq_len}",
884886
"--output_path outputs/outputs.txt",
887+
f"--performance_output_path {performance_output_path}",
885888
f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
886889
runner_args,
887890
]
@@ -905,7 +908,7 @@ def post_process():
905908
adb.pull(output_path=args.artifact, callback=post_process)
906909
if args.ip and args.port != -1:
907910
inference_speed = 0
908-
with open(f"{args.artifact}/outputs/inference_speed.txt", "r") as f:
911+
with open(f"{args.artifact}/{performance_output_path}", "r") as f:
909912
inference_speed = float(f.read())
910913

911914
pte_size = os.path.getsize(pte_path)

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ DEFINE_string(
3030
output_path,
3131
"outputs.txt",
3232
"Executorch inference data output path.");
33+
DEFINE_string(
34+
performance_output_path,
35+
"inference_speed.txt",
36+
"Records inference speed. For CI purpose.");
3337
DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
3438
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
3539
DEFINE_string(
@@ -63,6 +67,7 @@ int main(int argc, char** argv) {
6367
example::Runner runner(
6468
{FLAGS_model_path},
6569
FLAGS_tokenizer_path.c_str(),
70+
FLAGS_performance_output_path.c_str(),
6671
FLAGS_logits_scale,
6772
FLAGS_logits_offset,
6873
FLAGS_temperature,

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,16 @@ namespace example {
3434

3535
namespace {
3636
static constexpr auto kTopp = 0.9f;
37-
void printReport(const Runner::Stats& stats);
37+
void printReport(
38+
const Runner::Stats& stats,
39+
const std::string& performance_output_path);
3840
std::string statsToJsonString(const Runner::Stats& stats);
3941
} // namespace
4042

4143
Runner::Runner(
4244
const std::vector<std::string>& models_path,
4345
const std::string& tokenizer_path,
46+
const std::string& performance_output_path,
4447
const float logits_scale,
4548
const int32_t logits_offset,
4649
const float temperature,
@@ -49,6 +52,7 @@ Runner::Runner(
4952
: n_bos_(1),
5053
n_eos_(1),
5154
tokenizer_path_(tokenizer_path),
55+
performance_output_path_(performance_output_path),
5256
logits_scale_(logits_scale),
5357
logits_offset_(logits_offset),
5458
temperature_(temperature),
@@ -437,7 +441,7 @@ Error Runner::generate(
437441

438442
stats_.num_prompt_tokens = num_prompt_tokens;
439443
stats_.num_generated_tokens = pos - num_prompt_tokens;
440-
printReport(stats_);
444+
printReport(stats_, performance_output_path_);
441445
if (stats_callback) {
442446
stats_callback(stats_);
443447
}
@@ -446,7 +450,9 @@ Error Runner::generate(
446450
}
447451

448452
namespace {
449-
void printReport(const Runner::Stats& stats) {
453+
void printReport(
454+
const Runner::Stats& stats,
455+
const std::string& performance_output_path) {
450456
printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
451457

452458
ET_LOG(
@@ -507,7 +513,8 @@ void printReport(const Runner::Stats& stats) {
507513

508514
// For now, we just print the total inference time for CI, can save more info
509515
// in future if needed.
510-
std::ofstream outfile("outputs/inference_speed.txt");
516+
517+
std::ofstream outfile(performance_output_path.c_str());
511518
if (outfile.is_open()) {
512519
double num_tok = (stats.num_generated_tokens) /
513520
(double)(stats.inference_end_ms - stats.inference_start_ms) *

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Runner {
2929
explicit Runner(
3030
const std::vector<std::string>& models_path,
3131
const std::string& tokenizer_path,
32+
const std::string& performance_output_path_,
3233
const float logits_scale,
3334
const int32_t logits_offset,
3435
const float temperature,
@@ -101,6 +102,7 @@ class Runner {
101102
const int32_t n_eos_;
102103
std::vector<std::shared_ptr<executorch::extension::Module>> modules_;
103104
std::string tokenizer_path_;
105+
std::string performance_output_path_;
104106
float logits_scale_;
105107
int32_t logits_offset_;
106108
float temperature_;

examples/qualcomm/scripts/mobilebert_fine_tune.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size):
169169
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
170170
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
171171

172-
epochs = 5
172+
epochs = args.num_epochs
173173
dataloader_train = DataLoader(
174174
dataset_train,
175175
sampler=RandomSampler(dataset_train),
@@ -366,6 +366,13 @@ def calibrator(gm):
366366
type=str,
367367
)
368368

369+
parser.add_argument(
370+
"--num_epochs",
371+
help="If no pretrained weights are provided, set number of epochs to train the model",
372+
default=5,
373+
type=int,
374+
)
375+
369376
parser.add_argument(
370377
"-F",
371378
"--use_fp16",

examples/qualcomm/scripts/wav2letter.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
import sys
1011
from multiprocessing.connection import Client
@@ -111,7 +112,12 @@ def main(args):
111112
# target labels " abcdefghijklmnopqrstuvwxyz'*"
112113
instance.vocab_size = 29
113114
model = instance.get_eager_model().eval()
114-
model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))
115+
if args.pretrained_weight:
116+
model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))
117+
else:
118+
logging.warning(
119+
"It is strongly recommended to provide pretrained weights, otherwise accuracy will be bad. This option is here mainly for CI purpose to ensure compile is successful."
120+
)
115121

116122
# converting conv1d to conv2d at the nn.Module level only introduces 2 permute
117123
# nodes around input & output, which is more quantization friendly.
@@ -128,9 +134,15 @@ def main(args):
128134

129135
# retrieve dataset, will take some time to download
130136
data_num = 100
131-
inputs, targets, input_list = get_dataset(
132-
data_size=data_num, artifact_dir=args.artifact
133-
)
137+
if args.compile_only:
138+
inputs = [(torch.rand(1, 1, 700, 1),)]
139+
logging.warning(
140+
"With compile_only, accuracy will be bad due to insufficient datasets for quantization."
141+
)
142+
else:
143+
inputs, targets, input_list = get_dataset(
144+
data_size=data_num, artifact_dir=args.artifact
145+
)
134146
pte_filename = "w2l_qnn"
135147
build_executorch_binary(
136148
model,
@@ -212,7 +224,7 @@ def main(args):
212224
),
213225
default=None,
214226
type=str,
215-
required=True,
227+
required=False,
216228
)
217229

218230
args = parser.parse_args()

0 commit comments

Comments
 (0)