
Commit 1732d06

Qualcomm AI Engine Direct - Optimization in static llama
Summary:
- Fuse RMS norm
- Improve performance of div op
- Fix 16a8w annotation for matmul op
1 parent 21eecff commit 1732d06

File tree

3 files changed: +87 -7

- backends/qualcomm/quantizer/custom_annotation.py
- examples/qualcomm/oss_scripts/llama2/model/static_llama.py
- examples/qualcomm/oss_scripts/llama3_2/llama.py


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 81 additions & 0 deletions
@@ -16,13 +16,94 @@
     QUANT_ANNOTATION_KEY,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+    """
+    This function is specific for matmul op 16a8w.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_matmul_input1(node: Node):
+        quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
+            act_symmetric=True, act_observer=MinMaxObserver
+        )
+        while isinstance(node, Node) and node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.permute.default,
+                torch.ops.aten.transpose.int,
+            ]:
+                annotate_single_in_single_out(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.cat.default:
+                annotate_cat(node, quantization_config_8a8w)
+                node = node.args[0][0]
+            else:
+                node = node.args[0]
+
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
+            annotate_matmul_input1(node.args[1])
+
+
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for llama matmul op 16a8w.
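
For reference, a minimal sketch of exercising the new annotator on its own. The toy module and inputs are hypothetical, and it assumes a recent PyTorch where torch.export.export_for_training keeps matmul as torch.ops.aten.matmul.default; in this commit the annotator is actually wired in through compile()'s custom_annotations, as the llama.py diff below shows.

import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)


class TinyMatmul(torch.nn.Module):  # hypothetical toy module
    def forward(self, q, k):
        # The transpose feeds matmul's second input, so annotate_matmul_input1
        # tags it 8a8w while the matmul itself is annotated 16a8w.
        return torch.matmul(q, k.transpose(-2, -1))


example_inputs = (torch.randn(1, 8, 16), torch.randn(1, 8, 16))
gm = torch.export.export_for_training(TinyMatmul(), example_inputs).module()

annotate_matmul_16a8w(gm)  # stamps QUANT_ANNOTATION_KEY metadata on matching nodes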

examples/qualcomm/oss_scripts/llama2/model/static_llama.py

Lines changed: 3 additions & 4 deletions
@@ -13,7 +13,6 @@
     FeedForward,
     ModelArgs,
     precompute_freqs_cis,
-    RMSNorm,
 )


@@ -191,8 +190,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             config=config, output_new_cache_only=output_new_cache_only
         )
         self.feed_forward = FeedForward(config)
-        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

     def forward(
         self,
@@ -236,7 +235,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
                 for _ in range(config.n_layers)
             ]
         )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         freqs_cos, freqs_sin = precompute_freqs_cis(
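
This swap backs the "Fuse RMS norm" item in the summary: torch.nn.RMSNorm is a single module that downstream passes can pattern-match and fuse, and it computes the same normalization as the hand-written llama RMSNorm it replaces. A quick self-contained check of that equivalence, assuming PyTorch >= 2.4 (where torch.nn.RMSNorm was added) and the usual llama-style formulation:

import torch

dim, eps = 64, 1e-5
x = torch.randn(2, 8, dim)

rms = torch.nn.RMSNorm(dim, eps=eps)

# Hand-rolled RMS norm as llama-style models typically write it:
# x * rsqrt(mean(x^2) + eps), scaled by a learned weight.
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * rms.weight

torch.testing.assert_close(rms(x), manual)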

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 3 additions & 3 deletions
@@ -19,8 +19,8 @@
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

 from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_matmul_16a8w,
     custom_annotate_llama_last_conv_16a8w,
-    custom_annotate_llama_matmul_16a8w,
 )

 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
@@ -78,7 +78,7 @@ def calibrate(
     token_list = sp_model.encode(user_prompts, bos=True, eos=False)

     with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < 512:
+        while token_list[-1] != sp_model.eos_id and pos < 511:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 torch.full((1, 1), pos),
@@ -297,7 +297,7 @@ def compile(args):
             quant_dtype,
             custom_annotations=(
                 custom_annotate_llama_last_conv_16a8w,
-                custom_annotate_llama_matmul_16a8w,
+                annotate_matmul_16a8w,
             ),
         )
         end_quantize_ts = time.time()
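
For completeness, the annotator can also be registered on a quantizer directly. This is a sketch under the assumption that QnnQuantizer exposes an add_custom_quant_annotations hook taking a tuple of callables; the helper that compile() uses above passes custom_annotations through to the quantizer in a similar way.

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
    custom_annotate_llama_last_conv_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
# Assumed hook: registers callables that run over the captured GraphModule
# before the built-in annotation passes.
quantizer.add_custom_quant_annotations(
    (
        custom_annotate_llama_last_conv_16a8w,
        annotate_matmul_16a8w,
    )
)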
