Skip to content

Commit 7edf990

Browse files
author
Guang Yang
committed
Added soc model param to get_qnn_partitioner
1 parent e4a2322 commit 7edf990

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,18 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
510510
modelname = f"coreml_{modelname}"
511511

512512
if args.qnn:
513+
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
514+
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
515+
QcomChipset,
516+
)
513517
from executorch.extension.llm.custom_ops import model_sharding
514518

515519
partitioners.append(
516520
get_qnn_partitioner(
517-
args.use_kv_cache, args.pt2e_quantize, args.num_sharding
521+
QcomChipset.SM8650, # Llama 2 works only on SM8650
522+
args.use_kv_cache,
523+
args.pt2e_quantize,
524+
args.num_sharding,
518525
)
519526
)
520527
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`

extension/llm/export/partitioner_lib.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
from typing import Optional
89

10+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
11+
logging.basicConfig(level=logging.INFO, format=FORMAT)
12+
913

1014
def get_xnnpack_partitioner():
1115
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
@@ -105,6 +109,7 @@ def get_coreml_partitioner(
105109

106110

107111
def get_qnn_partitioner(
112+
soc_model,
108113
use_kv_cache: bool = False,
109114
pt2e_quantize: Optional[str] = None,
110115
num_sharding: int = 0,
@@ -118,11 +123,6 @@ def get_qnn_partitioner(
118123
QnnPartitioner,
119124
)
120125

121-
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
122-
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
123-
QcomChipset,
124-
)
125-
126126
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
127127
from executorch.backends.qualcomm.utils.utils import (
128128
generate_htp_compiler_spec,
@@ -133,14 +133,16 @@ def get_qnn_partitioner(
133133
"Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html"
134134
)
135135

136+
logging.info(f"Get QNN partitioner for {soc_model.name}.")
137+
136138
use_fp16 = True
137139
skip_node_op_set = {"llama.fallback.default"}
138140
if pt2e_quantize is not None:
139141
use_fp16 = False
140142

141143
return QnnPartitioner( # pyre-fixme[16]
142144
generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
143-
soc_model=QcomChipset.SM8450, # default to SM8450 # pyre-fixme[16]
145+
soc_model=soc_model, # pyre-fixme[16]
144146
# pyre-fixme[16]
145147
backend_options=generate_htp_compiler_spec(
146148
use_fp16=use_fp16,

0 commit comments

Comments (0)