Added quantization for evaluation script #11822

Merged: 1 commit, Jun 21, 2025
3 changes: 3 additions & 0 deletions examples/qualcomm/oss_scripts/llama/TARGETS
@@ -49,6 +49,9 @@ python_binary(
name = "eval_llama_qnn",
srcs = ["eval_llama_qnn.py"],
main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
preload_deps = [
"//executorch/extension/llm/custom_ops:model_sharding_py",
],
deps = [
":llama_lib",
"//executorch/examples/models/llama:eval_library",
115 changes: 101 additions & 14 deletions examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -8,50 +8,74 @@
import copy
import json

from typing import List, Optional, Tuple
import logging
import sys

from typing import List, Tuple

import torch
import torch.nn as nn
from executorch.backends.qualcomm.quantizer.custom_annotation import (
annotate_linear_16a8w_in_affine_layer,
annotate_matmul_16a8w,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d

from executorch.examples.models.llama.eval_llama_lib import (
build_args_parser,
GraphModuleEvalWrapper,
)

from executorch.examples.models.llama.source_transformation.quantize import (
get_quant_embedding_transform,
)

from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate

from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
LlamaModel,
ModelArgs,
)

from executorch.examples.qualcomm.utils import make_quantizer

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.getLogger().setLevel(logging.INFO)


class WrappedLlamaModel(nn.Module):
def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
def __init__(
self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
):
super(WrappedLlamaModel, self).__init__()
self.model = model
self.max_seq_len = max_seq_len
self.use_kv_cache = use_kv_cache
self.device = device
self.atten_mask = atten_mask

def forward(
self,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
*args,
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
# Pad input if necessary, since LlamaModel requires static shape
if tokens.shape[1] != self.max_seq_len:
tokens = torch.nn.functional.pad(
tokens, (self.max_seq_len - tokens.shape[1], 0)
tokens, (0, self.max_seq_len - tokens.shape[1])
)
atten_mask = (
self.model.get_example_inputs(self.use_kv_cache)[1]
.to(device=self.device)
.to(dtype=torch.bfloat16)
)
return self.model.forward(tokens, atten_mask, input_pos, *args)
return self.model.forward(tokens, self.atten_mask)


def gen_eval_wrapper(model_name, args):
@@ -119,14 +143,69 @@ def permute(w, heads):
layer.feed_forward.prepare_feedfoward_conv()

model.to(dtype=torch.bfloat16)
model.to(args.device)
model.to(device=args.device)

wrapped_model = WrappedLlamaModel(
model, args.use_kv_cache, args.max_seq_length, args.device
tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.bfloat16)
inputs = (tokens, atten_mask)

if args.embedding_quantize:
model = get_quant_embedding_transform(
embedding_quantize=args.embedding_quantize
)(model)

model = convert_linear_to_conv2d(model)

if args.ptq:
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

custom_annotations = (annotate_matmul_16a8w,)
if args.llama_model == "stories110m":
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
)
quantizer = make_quantizer(
quant_dtype=quant_dtype,
per_channel_conv=True,
per_channel_linear=True,
act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations(custom_annotations)

model.has_quant_io = True

with torch.no_grad():
model = torch.export.export(model, inputs, strict=True).module()
if quant_dtype == QuantDtype.use_16a4w_block:
conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
quantizer.set_block_size_map(block_size_map)

model = prepare_pt2e(model, quantizer)

logging.info("Quantizing the model...")

calibrate(
inputs,
"Once upon a time",
model,
tokenizer=tokenizer,
ar_len=args.prefill_ar_len,
max_seq_len=args.max_seq_len,
kv_updater=None,
use_i64_token=use_i64_token,
)

model = convert_pt2e(model)

model = WrappedLlamaModel(
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
)

return GraphModuleEvalWrapper(
model=wrapped_model,
model=model,
tokenizer=tokenizer,
max_seq_length=args.calibration_seq_length,
use_kv_cache=args.use_kv_cache,
@@ -167,6 +246,7 @@ def main() -> None:
modelname = "llama2"
parser = build_args_parser()
args = parser.parse_args()
args.llama_model = "llama3_2"
# Overrides this arg, because evaluation requires full logits.
args.generate_full_logits = True

@@ -177,7 +257,14 @@
args.use_kv_cache = False
args.prefill_ar_len = args.max_seq_length

# Use fewer samples for faster evaluation
args.limit = 0.1
# args.samples = {'wikitext': list(range(1))}

args.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(args.device)

args.ptq = "8a8w"

eval_llama(modelname, args)

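For reference, the quantization path this PR wires into the eval script is the standard PT2E flow: export the model with static shapes, annotate it with a Qualcomm quantizer, calibrate on a prompt, then convert. The sketch below condenses that flow into a hypothetical helper; quantize_for_eval and calibrate_fn are illustrative names rather than part of the PR, while the quantizer settings and the use_8a8w dtype mirror the values visible in the diff.

import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_for_eval(model, example_inputs, calibrate_fn, quant_dtype=QuantDtype.use_8a8w):
    # Hypothetical wrapper around the steps gen_eval_wrapper performs above.
    quantizer = make_quantizer(
        quant_dtype=quant_dtype,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )

    with torch.no_grad():
        # Export with static shapes so the quantizer sees one fixed graph.
        exported = torch.export.export(model, example_inputs, strict=True).module()
        prepared = prepare_pt2e(exported, quantizer)
        # Run representative data through the prepared graph to collect ranges.
        calibrate_fn(prepared)
        # Replace observers with quantize/dequantize ops.
        return convert_pt2e(prepared)

The quantized module is then wrapped in WrappedLlamaModel (so lm_eval sees a fixed-length, statically padded interface) before being handed to GraphModuleEvalWrapper, as the diff does.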