Commit 7122d31

Add proper pt2e calibration (#5095)
* Add proper pt2e calibration
* distinguish dynamic shape
* remove unnecessary code
* remove unnecessary code
* add comments
* Address comments and add template calibration
* remove logging
* address comments
* remove cuda
* add graph module eval wrapper
1 parent 20d93fb commit 7122d31

File tree

3 files changed: +183 -12 lines changed

examples/models/llama2/eval_llama_lib.py
Lines changed: 58 additions & 7 deletions

@@ -29,6 +29,51 @@
 )
 
 
+class GraphModuleEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for ExecuTorch py-binded integration with the
+    lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.fx.GraphModule,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+        enable_dynamic_shape: bool = True,
+    ):
+        super().__init__(
+            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+        )
+        self._model = model.to(self.device)
+        self._use_kv_cache = use_kv_cache
+        self._enable_dynamic_shape = enable_dynamic_shape
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            if not self._enable_dynamic_shape:
+                # graph module exported without dynamic shape won't work with a different shape.
+                # And we have to do single token prefill here.
+                result_logits = []
+                for pos in range(inps.shape[-1]):
+                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    result_logits.append(logits)
+                return torch.cat(result_logits, dim=1)
+            else:
+                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+                # Batch process the whole sequence.
+                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                return logits
+
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the

@@ -148,6 +193,13 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
         # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but

@@ -157,13 +209,12 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.model.eval().to(device="cpu")
         )
-
-    return EagerEvalWrapper(
-        model=model,
-        tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length,
-        use_kv_cache=args.use_kv_cache,
-    )
+        return EagerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+        )
 
 
 def build_args_parser() -> argparse.ArgumentParser:
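
For context, a minimal usage sketch of the new wrapper follows. It is not part of the diff; the tokenizer path, the captured graph module, and the task/limit values are placeholders, and evaluate_model is the same helper that the calibration path in builder.py imports.

# Hypothetical sketch, not part of this commit. Placeholders: the tokenizer path,
# graph_module, and the ["wikitext"] / 5 task and limit values.
from executorch.examples.models.llama2.eval_llama_lib import GraphModuleEvalWrapper
from executorch.examples.models.llama2.evaluate import evaluate_model
from executorch.extension.llm.tokenizer.utils import get_tokenizer

tokenizer = get_tokenizer("/path/to/tokenizer.model")  # placeholder path
graph_module = ...  # e.g. manager.pre_autograd_graph_module after capture_pre_autograd_graph()

eval_wrapper = GraphModuleEvalWrapper(
    model=graph_module,
    tokenizer=tokenizer,
    max_seq_length=128,
    use_kv_cache=True,
    enable_dynamic_shape=True,
)
eval_results = evaluate_model(eval_wrapper, ["wikitext"], 5)  # tasks, limit
for task, res in eval_results["results"].items():
    print(f"{task}: {res}")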

examples/models/llama2/export_llama_lib.py
Lines changed: 25 additions & 4 deletions

@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import pkg_resources

@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
         nargs="+",
         type=str,
         default=None,
-        help="Tasks for GPTQ calibration",
+        help="Tasks for GPTQ calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
         default=None,
-        help="number of samples used for calibration",
+        help="number of samples used for calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
         default=None,
-        help="Sequence length for GPTQ calibration",
+        help="Sequence length for GPTQ calibration from lm_eval",
+    )
+    parser.add_argument(
+        "--calibration_data",
+        type=str,
+        default="Once upon a time",
+        help="Calibration prompts from users",
     )
     parser.add_argument(
         "-t",

@@ -421,6 +427,11 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         generate_full_logits=args.generate_full_logits,
         weight_type=weight_type,
         enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
         metadata_str=args.metadata,

@@ -630,6 +641,11 @@ def _load_llama_model(
     generate_full_logits: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
     enable_dynamic_shape: bool = False,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    calibration_data: Optional[str] = None,
+    tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
     metadata_str: Optional[str] = None,

@@ -685,6 +701,11 @@ def _load_llama_model(
         use_kv_cache=use_kv_cache,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        calibration_data=calibration_data,
+        tokenizer_path=tokenizer_path,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
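
Put together, the new flags can be passed to the export entry point. The invocation below is only a sketch: the bracketed pieces stand for whatever checkpoint/params/quantizer flags your usual export command already uses, the task/limit/length values are illustrative, and the exact entry point may differ in your checkout.

python -m examples.models.llama2.export_llama \
    <your usual checkpoint/params/quantizer flags> \
    --use_kv_cache \
    -t <path/to/tokenizer.model> \
    --calibration_tasks wikitext \
    --calibration_limit 5 \
    --calibration_seq_length 128 \
    --calibration_data "Once upon a time"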

extension/llm/export/builder.py
Lines changed: 100 additions & 1 deletion

@@ -27,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer

@@ -66,6 +67,11 @@ def __init__(
         use_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
+        calibration_tasks: Optional[List[str]] = None,
+        calibration_limit: Optional[int] = None,
+        calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
+        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,

@@ -87,6 +93,11 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
+        self.calibration_tasks = calibration_tasks
+        self.calibration_limit = calibration_limit
+        self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
+        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """

@@ -167,6 +178,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
         )
         return self
 
+    def pt2e_calibrate(
+        self,
+        prepared_module,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        calibration_data,
+        tokenizer_path,
+    ):
+        logging.info("Run calibration...")
+        try:
+            from executorch.examples.models.llama2.eval_llama_lib import (
+                GraphModuleEvalWrapper,
+            )
+            from executorch.examples.models.llama2.evaluate import evaluate_model
+        except ImportError:
+            raise ImportError(
+                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
+            )
+
+        tokenizer = get_tokenizer(tokenizer_path)
+
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
+
+        eval_wrapper = GraphModuleEvalWrapper(
+            model=prepared_module,
+            tokenizer=tokenizer,
+            max_seq_length=calibration_seq_length,
+            use_kv_cache=self.use_kv_cache,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+        eval_results = evaluate_model(
+            eval_wrapper,
+            calibration_tasks,
+            calibration_limit,
+        )
+
+        for task, res in eval_results["results"].items():
+            print(f"{task}: {res}")
+        logging.info("Calibration finish...")
+
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.

@@ -189,8 +263,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             self.pre_autograd_graph_module is not None
         ), "Please run capture_pre_autograd_graph first"
         m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+        logging.info(
+            f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+        )
         # Calibrate
-        m(*self.example_inputs)
+        if (
+            self.calibration_tasks is not None
+            and self.calibration_limit is not None
+            and self.calibration_seq_length is not None
+            and self.calibration_data is not None
+            and self.tokenizer_path is not None
+        ):
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
+            self.pt2e_calibrate(
+                prepared_module=m,
+                calibration_tasks=self.calibration_tasks,
+                calibration_limit=self.calibration_limit,
+                calibration_seq_length=self.calibration_seq_length,
+                calibration_data=self.calibration_data,
+                tokenizer_path=self.tokenizer_path,
+            )
+        else:
+            logging.info(
+                "No calibration provided, using dummy input to calibrate..."
+            )
+            m(*self.example_inputs)
         m = convert_pt2e(m)
         DuplicateDynamicQuantChainPass()(m)
         self.pre_autograd_graph_module = m
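
In summary, this commit replaces the single dummy-input forward pass that previously calibrated the prepared module with a greedy decode over user-provided calibration text plus an lm_eval run, whenever all calibration arguments are supplied. The condensed sketch below (a hypothetical helper name, not the builder code itself) shows where that calibration step sits in the standard pt2e prepare/convert flow:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_with_calibration(graph_module, quantizer, example_inputs, run_calibration=None):
    # Insert observers into the captured graph.
    prepared = prepare_pt2e(graph_module, quantizer)
    if run_calibration is not None:
        # Feed real tokens through the prepared module, as pt2e_calibrate above does.
        run_calibration(prepared)
    else:
        # Previous behavior: calibrate with a single dummy forward pass.
        prepared(*example_inputs)
    # Fold the collected observer statistics into quantize/dequantize ops.
    return convert_pt2e(prepared)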
