Skip to content

Commit ba901f1

Browse files
author
Guang Yang
committed
Added soc model param to get_qnn_partitioner
1 parent 9062a09 commit ba901f1

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,18 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
510510
modelname = f"coreml_{modelname}"
511511

512512
if args.qnn:
513+
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
514+
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
515+
QcomChipset,
516+
)
513517
from executorch.extension.llm.custom_ops import model_sharding
514518

515519
partitioners.append(
516520
get_qnn_partitioner(
517-
args.use_kv_cache, args.pt2e_quantize, args.num_sharding
521+
QcomChipset.SM8650, # Llama 2 works only on SM8650
522+
args.use_kv_cache,
523+
args.pt2e_quantize,
524+
args.num_sharding,
518525
)
519526
)
520527
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`

extension/llm/export/partitioner_lib.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def get_coreml_partitioner(
105105

106106

107107
def get_qnn_partitioner(
108+
soc_model,
108109
use_kv_cache: bool = False,
109110
pt2e_quantize: Optional[str] = None,
110111
num_sharding: int = 0,
@@ -118,11 +119,6 @@ def get_qnn_partitioner(
118119
QnnPartitioner,
119120
)
120121

121-
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
122-
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
123-
QcomChipset,
124-
)
125-
126122
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
127123
from executorch.backends.qualcomm.utils.utils import (
128124
generate_htp_compiler_spec,
@@ -140,7 +136,7 @@ def get_qnn_partitioner(
140136

141137
return QnnPartitioner( # pyre-fixme[16]
142138
generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
143-
soc_model=QcomChipset.SM8450, # default to SM8450 # pyre-fixme[16]
139+
soc_model=soc_model, # pyre-fixme[16]
144140
# pyre-fixme[16]
145141
backend_options=generate_htp_compiler_spec(
146142
use_fp16=use_fp16,

0 commit comments

Comments
 (0)