Skip to content

Commit 85a0556

Browse files
author
Joey Tsai
committed
Fix lint
- Fix transformers version
- Refine pass quantization tagging function
1 parent fcc10de commit 85a0556

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from torch.fx import Node
2323

2424

25-
def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
25+
def annotate_matmul_16a8w(gm: torch.fx.GraphModule, traverse_input1=True) -> None:
2626
"""
2727
This function is specific for matmul op 16a8w.
2828
"""
@@ -99,7 +99,8 @@ def annotate_matmul_input1(node: Node):
9999
for node in gm.graph.nodes:
100100
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
101101
annotate_matmul(node, quantization_config_16a8w)
102-
annotate_matmul_input1(node.args[1])
102+
if traverse_input1:
103+
annotate_matmul_input1(node.args[1])
103104

104105

105106
def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import json
99
import logging
1010
import os
11-
1211
import sys
1312
import time
13+
from functools import partial
1414
from multiprocessing.connection import Client
1515

1616
import torch
@@ -319,8 +319,10 @@ def compile(args):
319319

320320
if args.model_mode == "kv":
321321
use_kv_cache = output_new_cache_only = True
322+
matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True)
322323
elif args.model_mode == "batch_prefill":
323324
use_kv_cache = output_new_cache_only = False
325+
matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False)
324326
elif args.model_mode == "hybrid":
325327
raise NotImplementedError(
326328
f"model_mode {args.model_mode} is not implemented yet."
@@ -387,7 +389,7 @@ def compile(args):
387389
quant_dtype,
388390
custom_annotations=(
389391
custom_annotate_llama_last_conv_16a8w,
390-
annotate_matmul_16a8w,
392+
matmul_annotate_func,
391393
),
392394
)
393395
end_quantize_ts = time.time()

install_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def python_is_compatible():
137137
"timm==1.0.7",
138138
f"torchaudio==2.5.0.{NIGHTLY_VERSION}" if USE_PYTORCH_NIGHTLY else "torchaudio",
139139
"torchsr==1.0.4",
140-
"transformers==4.42.4", # TODO update back to 4.46.1 once the error is fixed
140+
"transformers==4.46.1",
141141
]
142142

143143
# pip packages needed for development.

0 commit comments

Comments
 (0)