Skip to content

Commit 08c4742

Browse files
author
Joey Tsai
committed
Fix lint
- Fix transformers version
- Refine pass quantization tagging function
- Rebase
1 parent 7e6b3cc commit 08c4742

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from torch.fx import Node
2323

2424

25-
def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
25+
def annotate_matmul_16a8w(
26+
gm: torch.fx.GraphModule, traverse_input1=True
27+
) -> None: # noqa: C901
2628
"""
2729
This function is specific for matmul op 16a8w.
2830
"""
@@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node):
99101
for node in gm.graph.nodes:
100102
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
101103
annotate_matmul(node, quantization_config_16a8w)
102-
annotate_matmul_input1(node.args[1])
104+
if traverse_input1:
105+
annotate_matmul_input1(node.args[1])
103106

104107

105108
def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import json
99
import logging
1010
import os
11-
1211
import sys
1312
import time
13+
from functools import partial
1414
from multiprocessing.connection import Client
1515

1616
import torch
@@ -319,8 +319,10 @@ def compile(args):
319319

320320
if args.model_mode == "kv":
321321
use_kv_cache = output_new_cache_only = True
322+
matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True)
322323
elif args.model_mode == "batch_prefill":
323324
use_kv_cache = output_new_cache_only = False
325+
matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False)
324326
elif args.model_mode == "hybrid":
325327
raise NotImplementedError(
326328
f"model_mode {args.model_mode} is not implemented yet."
@@ -385,7 +387,10 @@ def compile(args):
385387
start_quantize_ts = time.time()
386388
single_llama.quantize(
387389
quant_dtype,
388-
custom_annotations=(annotate_matmul_16a8w,),
390+
custom_annotations=(
391+
custom_annotate_llama_last_conv_16a8w,
392+
matmul_annotate_func,
393+
),
389394
)
390395
end_quantize_ts = time.time()
391396
logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

install_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def python_is_compatible():
137137
"timm==1.0.7",
138138
f"torchaudio==2.5.0.{NIGHTLY_VERSION}" if USE_PYTORCH_NIGHTLY else "torchaudio",
139139
"torchsr==1.0.4",
140-
"transformers==4.42.4", # TODO update back to 4.46.1 once the error is fixed
140+
"transformers==4.46.1",
141141
]
142142

143143
# pip packages needed for development.

0 commit comments

Comments
 (0)