Skip to content

Commit 67ae762

Browse files
authored
Qualcomm AI Engine Direct - Add the argument to specify soc model (#5211)
* Qualcomm AI Engine Direct - Add the argument to specify soc model
* address review
1 parent f471556 commit 67ae762

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,14 @@ def build_args_parser() -> argparse.ArgumentParser:
321321
default=False,
322322
help="Generate logits for all inputs.",
323323
)
324+
325+
parser.add_argument(
326+
"--soc_model",
327+
help="[QNN backend] SoC model of current device. e.g. 'SM8650' for Snapdragon 8 Gen 3.",
328+
type=str,
329+
required=False,
330+
default="SM8650",
331+
)
324332
return parser
325333

326334

@@ -540,7 +548,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
540548

541549
partitioners.append(
542550
get_qnn_partitioner(
543-
args.use_kv_cache, args.pt2e_quantize, args.num_sharding
551+
args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
544552
)
545553
)
546554
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`

examples/qualcomm/utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,12 @@ def build_executorch_binary(
230230
else:
231231
edge_prog = capture_program(model, inputs)
232232

233-
arch_table = {
234-
"SM8650": QcomChipset.SM8650,
235-
"SM8550": QcomChipset.SM8550,
236-
"SM8475": QcomChipset.SM8475,
237-
"SM8450": QcomChipset.SM8450,
238-
}
239-
240233
backend_options = generate_htp_compiler_spec(
241234
use_fp16=False if quant_dtype else True
242235
)
243236
qnn_partitioner = QnnPartitioner(
244237
generate_qnn_executorch_compiler_spec(
245-
soc_model=arch_table[soc_model],
238+
soc_model=getattr(QcomChipset, soc_model),
246239
backend_options=backend_options,
247240
debug=False,
248241
saver=False,

extension/llm/export/partitioner_lib.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def get_qnn_partitioner(
108108
use_kv_cache: bool = False,
109109
pt2e_quantize: Optional[str] = None,
110110
num_sharding: int = 0,
111+
soc_model: str = "SM8650", # default to SM8650
111112
):
112113
assert (
113114
use_kv_cache is True
@@ -138,9 +139,16 @@ def get_qnn_partitioner(
138139
if pt2e_quantize is not None:
139140
use_fp16 = False
140141

142+
soc_chip_table = {
143+
"SM8650": QcomChipset.SM8650,
144+
"SM8550": QcomChipset.SM8550,
145+
"SM8475": QcomChipset.SM8475,
146+
"SM8450": QcomChipset.SM8450,
147+
}
148+
141149
return QnnPartitioner( # pyre-fixme[16]
142150
generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
143-
soc_model=QcomChipset.SM8650, # default to SM8650 # pyre-fixme[16]
151+
soc_model=soc_chip_table[soc_model], # pyre-fixme[16]
144152
# pyre-fixme[16]
145153
backend_options=generate_htp_compiler_spec(
146154
use_fp16=use_fp16,

0 commit comments

Comments
 (0)