Skip to content

Commit c76b22f

Browse files
shewu-quic and Sheng Feng Wu authored
Qualcomm AI Engine Direct - Fixed the order of the transforms for llama (#5221)
* Qualcomm AI Engine Direct - Fixed the order of the transforms for llama

* Fixed CI

---------

Co-authored-by: Sheng Feng Wu <[email protected]>
1 parent 02304d7 commit c76b22f

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ def __init__(
4141
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
4242
max_seq_length: Optional[int] = None,
4343
use_kv_cache: bool = False,
44+
generate_full_logits: bool = False,
4445
enable_dynamic_shape: bool = True,
4546
):
4647
super().__init__(
4748
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
4849
)
4950
self._model = model.to(self.device)
5051
self._use_kv_cache = use_kv_cache
52+
self._generate_full_logits = generate_full_logits
5153
self._enable_dynamic_shape = enable_dynamic_shape
5254

5355
def _model_call(self, inps):
@@ -60,7 +62,10 @@ def _model_call(self, inps):
6062
pos_tensor = torch.tensor([pos], dtype=torch.int64)
6163
logits = self._model(inps[:, pos : pos + 1], pos_tensor)
6264
result_logits.append(logits)
63-
return torch.cat(result_logits, dim=1)
65+
if self._generate_full_logits:
66+
return torch.cat(result_logits, dim=1)
67+
else:
68+
return torch.stack(result_logits, dim=1)
6469
else:
6570
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
6671
# Batch process the whole sequence.

examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def build_args_parser() -> argparse.ArgumentParser:
233233
"--optimized_rotation_path",
234234
default=None,
235235
required=False,
236-
help="[QNN Backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
236+
help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
237237
"You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
238238
)
239239
parser.add_argument(
@@ -440,6 +440,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
440440
transforms.append(replace_sdpa_with_flex_sdpa)
441441
transforms.append(replace_causal_mask)
442442
transforms.append(replace_rms_norm_with_native_rms_norm)
443+
if args.optimized_rotation_path:
444+
transforms.append(fuse_layer_norms)
445+
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
443446
transforms.append(convert_linear_to_conv2d)
444447

445448
elif args.coreml or args.mps:
@@ -448,9 +451,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
448451
transforms.append(replace_sdpa_with_simple_sdpa)
449452
transforms.append(replace_causal_mask)
450453

451-
if args.optimized_rotation_path:
452-
transforms.append(fuse_layer_norms)
453-
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
454454
return (
455455
_load_llama_model(
456456
modelname=modelname,
@@ -744,6 +744,7 @@ def _load_llama_model(
744744
max_seq_len=model.params.max_seq_len,
745745
dtype=dtype,
746746
use_kv_cache=use_kv_cache,
747+
generate_full_logits=generate_full_logits,
747748
example_inputs=example_inputs,
748749
enable_dynamic_shape=enable_dynamic_shape,
749750
calibration_tasks=calibration_tasks,

extension/llm/export/builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
example_inputs,
7070
args: Optional[Any] = None,
7171
enable_dynamic_shape: bool = False,
72+
generate_full_logits: bool = False,
7273
calibration_tasks: Optional[List[str]] = None,
7374
calibration_limit: Optional[int] = None,
7475
calibration_seq_length: Optional[int] = None,
@@ -86,6 +87,7 @@ def __init__(
8687
self.dtype = dtype
8788
self.example_inputs = example_inputs
8889
self.use_kv_cache = use_kv_cache
90+
self.generate_full_logits = generate_full_logits
8991
self.enable_dynamic_shape = enable_dynamic_shape
9092
self.verbose = verbose
9193
self.metadata = metadata
@@ -229,7 +231,12 @@ def calibrate_template(
229231
)
230232
pos += 1
231233
if pos >= len(token_list):
232-
token_list.append(torch.argmax(logits[:], dim=-1).item())
234+
if self.generate_full_logits:
235+
token_list.append(
236+
torch.argmax(logits[:, -1], dim=-1).item()
237+
)
238+
else:
239+
token_list.append(torch.argmax(logits[:], dim=-1).item())
233240

234241
calibrate_template(
235242
module=prepared_module,
@@ -243,6 +250,7 @@ def calibrate_template(
243250
tokenizer=tokenizer,
244251
max_seq_length=calibration_seq_length,
245252
use_kv_cache=self.use_kv_cache,
253+
generate_full_logits=self.generate_full_logits,
246254
enable_dynamic_shape=self.enable_dynamic_shape,
247255
)
248256
eval_results = evaluate_model(

extension/llm/export/partitioner_lib.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,9 @@ def get_qnn_partitioner(
139139
if pt2e_quantize is not None:
140140
use_fp16 = False
141141

142-
soc_chip_table = {
143-
"SM8650": QcomChipset.SM8650,
144-
"SM8550": QcomChipset.SM8550,
145-
"SM8475": QcomChipset.SM8475,
146-
"SM8450": QcomChipset.SM8450,
147-
}
148-
149142
return QnnPartitioner( # pyre-fixme[16]
150143
generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
151-
soc_model=soc_chip_table[soc_model], # pyre-fixme[16]
144+
soc_model=getattr(QcomChipset, soc_model), # pyre-fixme[16]
152145
# pyre-fixme[16]
153146
backend_options=generate_htp_compiler_spec(
154147
use_fp16=use_fp16,

0 commit comments

Comments (0)