Reland add proper calibration for pt2e flow #5152
Merged · 1 commit · Sep 8, 2024
82 changes: 67 additions & 15 deletions examples/models/llama2/eval_llama_lib.py
@@ -29,6 +29,51 @@
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for ExecuTorch py-binded integration with the
lm-evaluation-harness library.
"""

def __init__(
self,
model: torch.fx.GraphModule,
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
max_seq_length: Optional[int] = None,
use_kv_cache: bool = False,
enable_dynamic_shape: bool = True,
):
super().__init__(
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
)
self._model = model.to(self.device)
self._use_kv_cache = use_kv_cache
self._enable_dynamic_shape = enable_dynamic_shape

def _model_call(self, inps):
if self._use_kv_cache:
if not self._enable_dynamic_shape:
# graph module exported without dynamic shape won't work with a different shape.
# And we have to do single token prefill here.
result_logits = []
for pos in range(inps.shape[-1]):
pos_tensor = torch.tensor([pos], dtype=torch.int64)
logits = self._model(inps[:, pos : pos + 1], pos_tensor)
result_logits.append(logits)
return torch.cat(result_logits, dim=1)
else:
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
return logits

else:
return self._model(inps)

def _model_generate(self, context, max_length, eos_token_id):
raise Exception("unimplemented")


class ETPybindEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for ExecuTorch py-binded integration with the
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
if torch.cuda.is_available()
else manager.pre_autograd_graph_module.to(device="cpu")
)
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
enable_dynamic_shape=args.enable_dynamic_shape,
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
# for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -158,21 +210,21 @@ def gen_eval_wrapper(
else manager.model.eval().to(device="cpu")
)

# Save the checkpoint after the eager model preparation is done.
# The reason for this option is that the checkpoint can be used
# to do evaluations in other evaluation platforms, or with data
# that is not available in this eval_llama. We save the checkpoint
# here for consistency with eval_llama. The accuracy results we
# get from eval_llama can be used as a reference to other evaluations.
if args.output_eager_checkpoint_file is not None:
torch.save(model, args.output_eager_checkpoint_file)

return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
)
# Save the checkpoint after the eager model preparation is done.
# The reason for this option is that the checkpoint can be used
# to do evaluations in other evaluation platforms, or with data
# that is not available in this eval_llama. We save the checkpoint
# here for consistency with eval_llama. The accuracy results we
# get from eval_llama can be used as a reference to other evaluations.
if args.output_eager_checkpoint_file is not None:
torch.save(model, args.output_eager_checkpoint_file)

return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
)


def build_args_parser() -> argparse.ArgumentParser:
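The GraphModuleEvalWrapper added above is what gen_eval_wrapper() now returns for the pre-autograd graph module, and it is also reused by the calibration path in extension/llm/export/builder.py. A minimal sketch of driving it directly for a quick perplexity check (the tokenizer path, the captured graph module, and the task list below are placeholders, not part of this diff):

from executorch.examples.models.llama2.eval_llama_lib import GraphModuleEvalWrapper
from executorch.examples.models.llama2.evaluate import evaluate_model
from executorch.extension.llm.tokenizer.utils import get_tokenizer

tokenizer = get_tokenizer("/path/to/tokenizer.model")  # placeholder path
graph_module = ...  # torch.fx.GraphModule from capture_pre_autograd_graph, assumed available

wrapper = GraphModuleEvalWrapper(
    model=graph_module,
    tokenizer=tokenizer,
    max_seq_length=128,
    use_kv_cache=True,
    enable_dynamic_shape=False,  # takes the single-token prefill loop in _model_call
)
# Same positional call pattern that pt2e_calibrate uses below: wrapper, tasks, limit.
results = evaluate_model(wrapper, ["wikitext"], 1)
for task, res in results["results"].items():
    print(f"{task}: {res}")
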
29 changes: 25 additions & 4 deletions examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import pkg_resources

@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
nargs="+",
type=str,
default=None,
help="Tasks for GPTQ calibration",
help="Tasks for GPTQ calibration from lm_eval",
)
parser.add_argument(
"--calibration_limit",
type=int,
default=None,
help="number of samples used for calibration",
help="number of samples used for calibration from lm_eval",
)
parser.add_argument(
"--calibration_seq_length",
type=int,
default=None,
help="Sequence length for GPTQ calibration",
help="Sequence length for GPTQ calibration from lm_eval",
)
parser.add_argument(
"--calibration_data",
type=str,
default="Once upon a time",
help="Calibration prompts from users",
)
parser.add_argument(
"-t",
@@ -420,6 +426,11 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
generate_full_logits=args.generate_full_logits,
weight_type=weight_type,
enable_dynamic_shape=args.enable_dynamic_shape,
calibration_tasks=args.calibration_tasks,
calibration_limit=args.calibration_limit,
calibration_seq_length=args.calibration_seq_length,
calibration_data=args.calibration_data,
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
metadata_str=args.metadata,
@@ -630,6 +641,11 @@ def _load_llama_model(
generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
calibration_tasks: Optional[List[str]] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
calibration_data: Optional[str] = None,
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
metadata_str: Optional[str] = None,
@@ -686,6 +702,11 @@ def _load_llama_model(
use_kv_cache=use_kv_cache,
example_inputs=example_inputs,
enable_dynamic_shape=enable_dynamic_shape,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
calibration_seq_length=calibration_seq_length,
calibration_data=calibration_data,
tokenizer_path=tokenizer_path,
verbose=verbose,
metadata=_load_llama_model_metadata(
weight_type,
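The new calibration flags are threaded from build_args_parser() through _prepare_for_llama_export() into _load_llama_model(). A small sketch for sanity-checking the flag names (it assumes this import path and that no other argument is required for parsing; the tokenizer path is a placeholder):

from executorch.examples.models.llama2.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(
    [
        "--calibration_tasks", "wikitext",
        "--calibration_limit", "1",
        "--calibration_seq_length", "128",
        "--calibration_data", "Once upon a time",
        "--tokenizer_path", "/path/to/tokenizer.model",  # placeholder path
    ]
)
# These are the values _prepare_for_llama_export() forwards to _load_llama_model().
print(args.calibration_tasks, args.calibration_limit, args.calibration_seq_length)
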
16 changes: 16 additions & 0 deletions examples/models/llama2/tokenizer/targets.bzl
@@ -21,3 +21,19 @@ def define_common_targets():
"@EXECUTORCH_CLIENTS",
],
)

runtime.python_library(
name = "tiktoken_py",
srcs = [
"tiktoken.py",
],
_is_external_target = True,
visibility = [
"//bento/...",
"//bento_kernels/...",
"//executorch/...",
],
deps = [
"fbsource//third-party/pypi/tiktoken:tiktoken",
],
)
1 change: 1 addition & 0 deletions extension/llm/export/TARGETS
@@ -33,5 +33,6 @@ runtime.python_library(
"//executorch/exir:lib",
"//executorch/exir/backend:backend_details",
"//executorch/extension/export_util:export_util",
"//executorch/extension/llm/tokenizer:tokenizer_py_lib",
],
)
101 changes: 100 additions & 1 deletion extension/llm/export/builder.py
@@ -27,6 +27,7 @@
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
@@ -68,6 +69,11 @@ def __init__(
example_inputs,
args: Optional[Any] = None,
enable_dynamic_shape: bool = False,
calibration_tasks: Optional[List[str]] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
calibration_data: Optional[str] = None,
tokenizer_path: Optional[str] = None,
verbose: bool = False,
metadata: Optional[dict] = None,
dynamic_shapes: Optional[Any] = None,
@@ -90,6 +96,11 @@ def __init__(
self.dynamic_shapes = dynamic_shapes
self._saved_pte_filename = None
self.args = args
self.calibration_tasks = calibration_tasks
self.calibration_limit = calibration_limit
self.calibration_seq_length = calibration_seq_length
self.calibration_data = calibration_data
self.tokenizer_path = tokenizer_path

def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
"""
@@ -181,6 +192,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":

return self

def pt2e_calibrate(
self,
prepared_module,
calibration_tasks,
calibration_limit,
calibration_seq_length,
calibration_data,
tokenizer_path,
):
logging.info("Run calibration...")
try:
from executorch.examples.models.llama2.eval_llama_lib import (
GraphModuleEvalWrapper,
)
from executorch.examples.models.llama2.evaluate import evaluate_model
except ImportError:
raise ImportError(
"Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
)

tokenizer = get_tokenizer(tokenizer_path)

def calibrate_template(
module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
):
# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int64)
token_list = tokenizer.encode(prompts, bos=True, eos=False)

with torch.no_grad():
while token_list[-1] != tokenizer.eos_id and pos < max_len:
logits = module(
torch.full((1, 1), token_list[pos]),
torch.tensor((pos,)),
)
pos += 1
if pos >= len(token_list):
token_list.append(torch.argmax(logits[:], dim=-1).item())

calibrate_template(
module=prepared_module,
tokenizer=tokenizer,
prompts=calibration_data,
max_len=calibration_seq_length,
)

eval_wrapper = GraphModuleEvalWrapper(
model=prepared_module,
tokenizer=tokenizer,
max_seq_length=calibration_seq_length,
use_kv_cache=self.use_kv_cache,
enable_dynamic_shape=self.enable_dynamic_shape,
)
eval_results = evaluate_model(
eval_wrapper,
calibration_tasks,
calibration_limit,
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")
logging.info("Calibration finish...")

def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
"""
Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -203,8 +277,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
self.pre_autograd_graph_module is not None
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
m(*self.example_inputs)
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
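The pt2e_calibrate() hook above follows the standard pt2e recipe: prepare_pt2e() inserts observers, representative data is run through the prepared module, and convert_pt2e() folds the collected statistics into quantized ops. A generic, self-contained sketch of that recipe for reference (XNNPACKQuantizer is only an illustrative quantizer choice, not something this diff mandates):

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

def quantize_with_calibration(model, example_inputs, calibration_batches):
    exported = capture_pre_autograd_graph(model, example_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(exported, quantizer)
    # Observers record activation statistics here; pt2e_calibrate() plays the same
    # role, but feeds tokens decoded from calibration_data plus lm_eval tasks
    # instead of arbitrary batches.
    for batch in calibration_batches:
        prepared(*batch)
    return convert_pt2e(prepared)

In LLMEdgeManager the same steps are split across capture_pre_autograd_graph() and pt2e_quantize(), with the calibration loop replaced by pt2e_calibrate() when all calibration arguments are provided and by a single m(*self.example_inputs) call otherwise.
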
22 changes: 3 additions & 19 deletions extension/llm/tokenizer/targets.bzl
@@ -11,36 +11,20 @@ def define_common_targets():
srcs = [
"__init__.py",
"tokenizer.py",
"utils.py",
],
base_module = "executorch.extension.llm.tokenizer",
visibility = [
"//executorch/examples/...",
"//executorch/extension/llm/tokenizer/...",
"//executorch/extension/llm/export/...",
"//bento/...",
"//bento_kernels/...",
],
_is_external_target = True,
external_deps = [
"sentencepiece-py",
],
)

runtime.python_library(
name = "utils",
srcs = [
"utils.py",
],
base_module = "executorch.extension.llm.utils",
visibility = [
"//executorch/examples/...",
"//executorch/extension/llm/tokenizer/...",
"//bento/...",
"//bento_kernels/...",
],
deps = [
"//executorch/examples/models/llama2/tokenizer:tiktoken",
"//executorch/examples/models/llama2/tokenizer:tiktoken_py",
],
_is_external_target = True,
external_deps = [
"sentencepiece-py",
],