Skip to content

Qualcomm AI Engine Direct - Optimization in static llama #6849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,94 @@
QUANT_ANNOTATION_KEY,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
"""
This function is specific for matmul op 16a8w.
"""

def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec

input_act1 = node.args[1]
input_spec1 = quantization_config.weight
input_qspec_map[input_act1] = input_spec1

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

def annotate_cat(node: Node, quantization_config: QuantizationConfig):
input_nodes = node.args[0]

first_input_node = input_nodes[0]
input_qspec_map = {}
input_qspec_map[first_input_node] = quantization_config.input_activation
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, node)
)

for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)

def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:

input_qspec_map = {}
input_act = node.args[0]
input_qspec_map[input_act] = quantization_config.input_activation

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

def annotate_matmul_input1(node: Node):
quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
while isinstance(node, Node) and node.op == "call_function":
if node.target in [
torch.ops.aten.permute.default,
torch.ops.aten.transpose.int,
]:
annotate_single_in_single_out(node, quantization_config_8a8w)
node = node.args[0]
elif node.target == torch.ops.aten.cat.default:
annotate_cat(node, quantization_config_8a8w)
node = node.args[0][0]
Comment on lines +93 to +95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What pattern is this trying to capture?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following pattern.

                        q (16 bits) -------\
                                 matmul op (16 bits)
past k / v (8 bits) -------\
                     cat op (8 bits) ----/
new k / v (8 bits)---------/
(transpose after k)

else:
node = node.args[0]

quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)

for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
"""
This function is specific for llama matmul op 16a8w.
Expand Down
7 changes: 3 additions & 4 deletions examples/qualcomm/oss_scripts/llama2/model/static_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
FeedForward,
ModelArgs,
precompute_freqs_cis,
RMSNorm,
)


Expand Down Expand Up @@ -191,8 +190,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
config=config, output_new_cache_only=output_new_cache_only
)
self.feed_forward = FeedForward(config)
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

def forward(
self,
Expand Down Expand Up @@ -236,7 +235,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
for _ in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
freqs_cos, freqs_sin = precompute_freqs_cis(
Expand Down
6 changes: 3 additions & 3 deletions examples/qualcomm/oss_scripts/llama3_2/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

from executorch.backends.qualcomm.quantizer.custom_annotation import (
annotate_matmul_16a8w,
custom_annotate_llama_last_conv_16a8w,
custom_annotate_llama_matmul_16a8w,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
Expand Down Expand Up @@ -78,7 +78,7 @@ def calibrate(
token_list = sp_model.encode(user_prompts, bos=True, eos=False)

with torch.no_grad():
while token_list[-1] != sp_model.eos_id and pos < 512:
while token_list[-1] != sp_model.eos_id and pos < 511:
logits, new_k_caches, new_v_caches = module(
torch.full((1, 1), token_list[pos], dtype=torch.int32),
torch.full((1, 1), pos),
Expand Down Expand Up @@ -297,7 +297,7 @@ def compile(args):
quant_dtype,
custom_annotations=(
custom_annotate_llama_last_conv_16a8w,
custom_annotate_llama_matmul_16a8w,
annotate_matmul_16a8w,
),
)
end_quantize_ts = time.time()
Expand Down
Loading