File tree: 3 files changed (+8 -5 lines changed)
backends/qualcomm/quantizer
examples/qualcomm/oss_scripts/llama3_2 Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Original file line number Diff line number Diff line change 22
22
from torch .fx import Node
23
23
24
24
25
- def annotate_matmul_16a8w (gm : torch .fx .GraphModule ) -> None :
25
+ def annotate_matmul_16a8w (gm : torch .fx .GraphModule , traverse_input1 = True ) -> None :
26
26
"""
27
27
This function is specific for matmul op 16a8w.
28
28
"""
@@ -99,7 +99,8 @@ def annotate_matmul_input1(node: Node):
99
99
for node in gm .graph .nodes :
100
100
if node .op == "call_function" and node .target == torch .ops .aten .matmul .default :
101
101
annotate_matmul (node , quantization_config_16a8w )
102
- annotate_matmul_input1 (node .args [1 ])
102
+ if traverse_input1 :
103
+ annotate_matmul_input1 (node .args [1 ])
103
104
104
105
105
106
def custom_annotate_llama_matmul_16a8w (gm : torch .fx .GraphModule ) -> None : # noqa: C901
Original file line number Diff line number Diff line change 8
8
import json
9
9
import logging
10
10
import os
11
-
12
11
import sys
13
12
import time
13
+ from functools import partial
14
14
from multiprocessing .connection import Client
15
15
16
16
import torch
@@ -319,8 +319,10 @@ def compile(args):
319
319
320
320
if args .model_mode == "kv" :
321
321
use_kv_cache = output_new_cache_only = True
322
+ matmul_annotate_func = partial (annotate_matmul_16a8w , traverse_input1 = True )
322
323
elif args .model_mode == "batch_prefill" :
323
324
use_kv_cache = output_new_cache_only = False
325
+ matmul_annotate_func = partial (annotate_matmul_16a8w , traverse_input1 = False )
324
326
elif args .model_mode == "hybrid" :
325
327
raise NotImplementedError (
326
328
f"model_mode { args .model_mode } is not implemented yet."
@@ -387,7 +389,7 @@ def compile(args):
387
389
quant_dtype ,
388
390
custom_annotations = (
389
391
custom_annotate_llama_last_conv_16a8w ,
390
- annotate_matmul_16a8w ,
392
+ matmul_annotate_func ,
391
393
),
392
394
)
393
395
end_quantize_ts = time .time ()
Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def python_is_compatible():
137
137
"timm==1.0.7" ,
138
138
f"torchaudio==2.5.0.{ NIGHTLY_VERSION } " if USE_PYTORCH_NIGHTLY else "torchaudio" ,
139
139
"torchsr==1.0.4" ,
140
- "transformers==4.42.4" , # TODO update back to 4.46.1 once the error is fixed
140
+ "transformers==4.46.1" ,
141
141
]
142
142
143
143
# pip packages needed for development.
You can’t perform that action at this time.
0 commit comments