Add proper pt2e calibration #5095

Merged (10 commits) on Sep 6, 2024
65 changes: 58 additions & 7 deletions examples/models/llama2/eval_llama_lib.py
@@ -29,6 +29,51 @@
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for evaluating a torch.fx.GraphModule (e.g. the pre-autograd
exported Llama module) with the lm-evaluation-harness library.
"""

def __init__(
self,
model: torch.fx.GraphModule,
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
max_seq_length: Optional[int] = None,
use_kv_cache: bool = False,
enable_dynamic_shape: bool = True,
):
super().__init__(
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
)
self._model = model.to(self.device)
self._use_kv_cache = use_kv_cache
self._enable_dynamic_shape = enable_dynamic_shape

def _model_call(self, inps):
if self._use_kv_cache:
if not self._enable_dynamic_shape:
# A graph module exported without dynamic shapes only works with the shape it was exported with,
# so we have to do single-token prefill here.
result_logits = []
for pos in range(inps.shape[-1]):
pos_tensor = torch.tensor([pos], dtype=torch.int64)
logits = self._model(inps[:, pos : pos + 1], pos_tensor)
result_logits.append(logits)
return torch.cat(result_logits, dim=1)
else:
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
return logits

else:
return self._model(inps)

def _model_generate(self, context, max_length, eos_token_id):
raise NotImplementedError("_model_generate is not supported for GraphModuleEvalWrapper")
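
For orientation, a minimal usage sketch of this wrapper (a hedged example: the graph module is assumed to come from LLMEdgeManager.capture_pre_autograd_graph() as in gen_eval_wrapper below, and the tokenizer path, task, and limit are placeholders):

    from executorch.examples.models.llama2.evaluate import evaluate_model
    from executorch.extension.llm.tokenizer.utils import get_tokenizer

    tokenizer = get_tokenizer("/path/to/tokenizer.model")  # SentencePiece or Tiktoken
    wrapper = GraphModuleEvalWrapper(
        model=manager.pre_autograd_graph_module,
        tokenizer=tokenizer,
        max_seq_length=128,
        use_kv_cache=True,
        enable_dynamic_shape=False,
    )
    eval_results = evaluate_model(wrapper, ["wikitext"], 5)  # tasks, limit
    print(eval_results["results"])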


class ETPybindEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for ExecuTorch py-binded integration with the
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
if torch.cuda.is_available()
else manager.pre_autograd_graph_module.to(device="cpu")
)
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
enable_dynamic_shape=args.enable_dynamic_shape,
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
# for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -157,13 +209,12 @@ def gen_eval_wrapper(
if torch.cuda.is_available()
else manager.model.eval().to(device="cpu")
)

return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
)
return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
)


def build_args_parser() -> argparse.ArgumentParser:
29 changes: 25 additions & 4 deletions examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import pkg_resources

@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
nargs="+",
type=str,
default=None,
help="Tasks for GPTQ calibration",
help="Tasks for GPTQ calibration from lm_eval",
[Contributor review comment] nit: For future reference, separate out unrelated fixes.
)
parser.add_argument(
"--calibration_limit",
type=int,
default=None,
help="number of samples used for calibration",
help="number of samples used for calibration from lm_eval",
)
parser.add_argument(
"--calibration_seq_length",
type=int,
default=None,
help="Sequence length for GPTQ calibration",
help="Sequence length for GPTQ calibration from lm_eval",
)
parser.add_argument(
"--calibration_data",
type=str,
default="Once upon a time",
help="Calibration prompts from users",
)
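# For context, a hedged sketch of how these calibration flags might be used
# together (the export entry point, the quantizer flag, and the checkpoint/params
# flags are assumptions about the surrounding tooling, not part of this diff):
#
#   python -m examples.models.llama2.export_llama \
#       <your usual checkpoint/params flags> \
#       --use_kv_cache --pt2e_quantize xnnpack_dynamic \
#       --tokenizer_path <tokenizer.model> \
#       --calibration_tasks wikitext \
#       --calibration_limit 5 \
#       --calibration_seq_length 128 \
#       --calibration_data "Once upon a time"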
parser.add_argument(
"-t",
@@ -421,6 +427,11 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
generate_full_logits=args.generate_full_logits,
weight_type=weight_type,
enable_dynamic_shape=args.enable_dynamic_shape,
calibration_tasks=args.calibration_tasks,
calibration_limit=args.calibration_limit,
calibration_seq_length=args.calibration_seq_length,
calibration_data=args.calibration_data,
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
metadata_str=args.metadata,
@@ -630,6 +641,11 @@ def _load_llama_model(
generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
calibration_tasks: Optional[List[str]] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
calibration_data: Optional[str] = None,
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
metadata_str: Optional[str] = None,
@@ -685,6 +701,11 @@ def _load_llama_model(
use_kv_cache=use_kv_cache,
example_inputs=example_inputs,
enable_dynamic_shape=enable_dynamic_shape,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
calibration_seq_length=calibration_seq_length,
calibration_data=calibration_data,
tokenizer_path=tokenizer_path,
verbose=verbose,
metadata=_load_llama_model_metadata(
weight_type,
101 changes: 100 additions & 1 deletion extension/llm/export/builder.py
@@ -27,6 +27,7 @@
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
@@ -66,6 +67,11 @@ def __init__(
use_kv_cache,
example_inputs,
enable_dynamic_shape: bool = False,
calibration_tasks: Optional[List[str]] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
calibration_data: Optional[str] = None,
tokenizer_path: Optional[str] = None,
verbose: bool = False,
metadata: Optional[dict] = None,
dynamic_shapes: Optional[Any] = None,
@@ -87,6 +93,11 @@
self.output_dir = "."
self.dynamic_shapes = dynamic_shapes
self._saved_pte_filename = None
self.calibration_tasks = calibration_tasks
self.calibration_limit = calibration_limit
self.calibration_seq_length = calibration_seq_length
self.calibration_data = calibration_data
self.tokenizer_path = tokenizer_path

def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
"""
@@ -167,6 +178,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
)
return self

def pt2e_calibrate(
[Contributor review comment] This method should not be part of the builder at all; the builder is meant to produce a model, not to calibrate one.

Hence my suggestion was to move the functionality of this method either inside GraphModuleEvalWrapper or somewhere else.
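
A rough sketch of what that suggestion could look like (illustrative and untested; it reuses get_tokenizer, evaluate_model, and GraphModuleEvalWrapper exactly as this PR already does, but keeps the calibration loop next to the eval wrapper instead of inside the builder):

    # Hypothetical helper in eval_llama_lib.py; not part of this diff.
    import torch

    from executorch.examples.models.llama2.evaluate import evaluate_model
    from executorch.extension.llm.tokenizer.utils import get_tokenizer

    def calibrate_prepared_module(
        prepared_module: torch.fx.GraphModule,
        tokenizer_path: str,
        calibration_data: str,
        calibration_tasks: list,
        calibration_limit: int,
        calibration_seq_length: int,
        use_kv_cache: bool,
        enable_dynamic_shape: bool,
    ) -> None:
        tokenizer = get_tokenizer(tokenizer_path)
        token_list = tokenizer.encode(calibration_data, bos=True, eos=False)

        # Greedy single-token decode so the prepared module's observers see real data.
        pos = 0
        with torch.no_grad():
            while token_list[-1] != tokenizer.eos_id and pos < calibration_seq_length:
                logits = prepared_module(
                    torch.full((1, 1), token_list[pos]),
                    torch.tensor((pos,)),
                )
                pos += 1
                if pos >= len(token_list):
                    token_list.append(torch.argmax(logits, dim=-1).item())

        # Optionally also run lm_eval tasks through the prepared module.
        eval_wrapper = GraphModuleEvalWrapper(
            model=prepared_module,
            tokenizer=tokenizer,
            max_seq_length=calibration_seq_length,
            use_kv_cache=use_kv_cache,
            enable_dynamic_shape=enable_dynamic_shape,
        )
        evaluate_model(eval_wrapper, calibration_tasks, calibration_limit)

The builder's pt2e_quantize would then just call this helper on the prepared module rather than owning the calibration logic itself.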

self,
prepared_module,
calibration_tasks,
calibration_limit,
calibration_seq_length,
calibration_data,
tokenizer_path,
):
logging.info("Run calibration...")
try:
from executorch.examples.models.llama2.eval_llama_lib import (
GraphModuleEvalWrapper,
)
from executorch.examples.models.llama2.evaluate import evaluate_model
except ImportError:
raise ImportError(
"Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
)

tokenizer = get_tokenizer(tokenizer_path)

def calibrate_template(
module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
):
# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int64)
token_list = tokenizer.encode(prompts, bos=True, eos=False)

with torch.no_grad():
while token_list[-1] != tokenizer.eos_id and pos < max_len:
logits = module(
torch.full((1, 1), token_list[pos]),
torch.tensor((pos,)),
)
pos += 1
if pos >= len(token_list):
token_list.append(torch.argmax(logits[:], dim=-1).item())

calibrate_template(
module=prepared_module,
tokenizer=tokenizer,
prompts=calibration_data,
max_len=calibration_seq_length,
)

eval_wrapper = GraphModuleEvalWrapper(
model=prepared_module,
tokenizer=tokenizer,
max_seq_length=calibration_seq_length,
use_kv_cache=self.use_kv_cache,
enable_dynamic_shape=self.enable_dynamic_shape,
)
eval_results = evaluate_model(
eval_wrapper,
calibration_tasks,
calibration_limit,
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")
logging.info("Calibration finish...")

def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
"""
Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -189,8 +263,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
self.pre_autograd_graph_module is not None
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
m(*self.example_inputs)
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m