File tree Expand file tree Collapse file tree 3 files changed +13
-5
lines changed
backends/qualcomm/quantizer
examples/qualcomm/oss_scripts/llama3_2 Expand file tree Collapse file tree 3 files changed +13
-5
lines changed Original file line number Diff line number Diff line change 22
22
from torch .fx import Node
23
23
24
24
25
- def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
25
+ def annotate_matmul_16a8w(
26
+     gm: torch.fx.GraphModule, traverse_input1=True
27
+ ) -> None:  # noqa: C901
26
28
"""
27
29
This function is specific for matmul op 16a8w.
28
30
"""
@@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node):
99
101
for node in gm .graph .nodes :
100
102
if node .op == "call_function" and node .target == torch .ops .aten .matmul .default :
101
103
annotate_matmul (node , quantization_config_16a8w )
102
- annotate_matmul_input1(node.args[1])
104
+ if traverse_input1:
105
+     annotate_matmul_input1(node.args[1])
103
106
104
107
105
108
def custom_annotate_llama_matmul_16a8w (gm : torch .fx .GraphModule ) -> None : # noqa: C901
Original file line number Diff line number Diff line change 8
8
import json
9
9
import logging
10
10
import os
11
-
12
11
import sys
13
12
import time
13
+ from functools import partial
14
14
from multiprocessing .connection import Client
15
15
16
16
import torch
@@ -319,8 +319,10 @@ def compile(args):
319
319
320
320
if args .model_mode == "kv" :
321
321
use_kv_cache = output_new_cache_only = True
322
+ matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True)
322
323
elif args .model_mode == "batch_prefill" :
323
324
use_kv_cache = output_new_cache_only = False
325
+ matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False)
324
326
elif args .model_mode == "hybrid" :
325
327
raise NotImplementedError (
326
328
f"model_mode { args .model_mode } is not implemented yet."
@@ -385,7 +387,10 @@ def compile(args):
385
387
start_quantize_ts = time .time ()
386
388
single_llama .quantize (
387
389
quant_dtype ,
388
- custom_annotations=(annotate_matmul_16a8w,),
390
+ custom_annotations=(
391
+     custom_annotate_llama_last_conv_16a8w,
392
+     matmul_annotate_func,
393
+ ),
389
394
)
390
395
end_quantize_ts = time .time ()
391
396
logging .info (f"Time for quantizing: { end_quantize_ts - start_quantize_ts } " )
Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def python_is_compatible():
137
137
"timm==1.0.7" ,
138
138
f"torchaudio==2.5.0.{ NIGHTLY_VERSION } " if USE_PYTORCH_NIGHTLY else "torchaudio" ,
139
139
"torchsr==1.0.4" ,
140
- "transformers==4.42.4" , # TODO update back to 4. 46.1 once the error is fixed
140
+ "transformers==4.46.1" ,
141
141
]
142
142
143
143
# pip packages needed for development.
You can’t perform that action at this time.
0 commit comments