
Commit 29ddffb

cccclai authored and facebook-github-bot committed
Reland add proper calibration for pt2e flow (#5152)
Summary: Pull Request resolved: #5152. See the discussion in #5095. Relanded because of an internal failure. Differential Revision: D62323396
1 parent 13da62b commit 29ddffb

6 files changed: +212 -39 lines changed
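In substance, this change replaces the single dummy-input forward pass previously used to calibrate a pt2e-prepared module with greedy decoding of a user-supplied prompt (optionally followed by lm_eval tasks), so quantization observers record realistic activations. Below is a distilled sketch of the `calibrate_template` helper this commit adds in builder.py; `module` and `tokenizer` are stand-ins for a pt2e-prepared graph module and a SentencePiece/Tiktoken tokenizer, not exact code from the diff:

import torch

def greedy_calibrate(module, tokenizer, prompt: str, max_len: int) -> None:
    # Feed the prompt one token at a time; once past the prompt, append the
    # argmax token so calibration follows a realistic decoding trajectory.
    tokens = tokenizer.encode(prompt, bos=True, eos=False)
    pos = 0
    with torch.no_grad():
        while tokens[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(
                torch.full((1, 1), tokens[pos]),         # current token id
                torch.tensor([pos], dtype=torch.int64),  # kv-cache position
            )
            pos += 1
            if pos >= len(tokens):
                tokens.append(torch.argmax(logits, dim=-1).item())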

examples/models/llama2/eval_llama_lib.py

Lines changed: 67 additions & 15 deletions
@@ -29,6 +29,51 @@
 )


+class GraphModuleEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for ExecuTorch py-binded integration with the
+    lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.fx.GraphModule,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+        enable_dynamic_shape: bool = True,
+    ):
+        super().__init__(
+            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+        )
+        self._model = model.to(self.device)
+        self._use_kv_cache = use_kv_cache
+        self._enable_dynamic_shape = enable_dynamic_shape
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            if not self._enable_dynamic_shape:
+                # A graph module exported without dynamic shapes won't work with
+                # a different shape, so we have to do single-token prefill here.
+                result_logits = []
+                for pos in range(inps.shape[-1]):
+                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    result_logits.append(logits)
+                return torch.cat(result_logits, dim=1)
+            else:
+                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+                # Batch process the whole sequence.
+                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                return logits
+
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
         # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -158,21 +210,21 @@ def gen_eval_wrapper(
             else manager.model.eval().to(device="cpu")
         )

-    # Save the checkpoint after the eager model preparation is done.
-    # The reason for this option is that the checkpoint can be used
-    # to do evaluations in other evaluation platforms, or with data
-    # that is not available in this eval_llama. We save the checkpoint
-    # here for consistency with eval_llama. The accuracy results we
-    # get from eval_llama can be used as a reference to other evaluations.
-    if args.output_eager_checkpoint_file is not None:
-        torch.save(model, args.output_eager_checkpoint_file)
-
-    return EagerEvalWrapper(
-        model=model,
-        tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length,
-        use_kv_cache=args.use_kv_cache,
-    )
+        # Save the checkpoint after the eager model preparation is done.
+        # The reason for this option is that the checkpoint can be used
+        # to do evaluations in other evaluation platforms, or with data
+        # that is not available in this eval_llama. We save the checkpoint
+        # here for consistency with eval_llama. The accuracy results we
+        # get from eval_llama can be used as a reference to other evaluations.
+        if args.output_eager_checkpoint_file is not None:
+            torch.save(model, args.output_eager_checkpoint_file)
+
+        return EagerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+        )


 def build_args_parser() -> argparse.ArgumentParser:
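For orientation, here is a minimal hypothetical use of the new wrapper; `prepared_module` and `tokenizer` are assumed to come from `prepare_pt2e` and `get_tokenizer`, and the token ids are made up:

import torch

# Hypothetical usage sketch; not part of this commit.
wrapper = GraphModuleEvalWrapper(
    model=prepared_module,        # torch.fx.GraphModule from prepare_pt2e
    tokenizer=tokenizer,
    max_seq_length=128,
    use_kv_cache=True,
    enable_dynamic_shape=False,   # forces the single-token prefill path above
)
inps = torch.tensor([[1, 15043, 29892]])  # illustrative token ids
logits = wrapper._model_call(inps)        # one forward call per position,
                                          # logits concatenated along dim 1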

examples/models/llama2/export_llama_lib.py

Lines changed: 25 additions & 4 deletions
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union

 import pkg_resources

@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
         nargs="+",
         type=str,
         default=None,
-        help="Tasks for GPTQ calibration",
+        help="Tasks for GPTQ calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
         default=None,
-        help="number of samples used for calibration",
+        help="number of samples used for calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
         default=None,
-        help="Sequence length for GPTQ calibration",
+        help="Sequence length for GPTQ calibration from lm_eval",
+    )
+    parser.add_argument(
+        "--calibration_data",
+        type=str,
+        default="Once upon a time",
+        help="Calibration prompts from users",
     )
     parser.add_argument(
         "-t",
@@ -420,6 +426,11 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         generate_full_logits=args.generate_full_logits,
         weight_type=weight_type,
         enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
         metadata_str=args.metadata,
@@ -630,6 +641,11 @@ def _load_llama_model(
     generate_full_logits: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
     enable_dynamic_shape: bool = False,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    calibration_data: Optional[str] = None,
+    tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
     metadata_str: Optional[str] = None,
@@ -686,6 +702,11 @@ def _load_llama_model(
         use_kv_cache=use_kv_cache,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        calibration_data=calibration_data,
+        tokenizer_path=tokenizer_path,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
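These flags thread from the CLI through `_prepare_for_llama_export` into `_load_llama_model`. A hedged sketch of exercising them programmatically; every value is illustrative (only the `--calibration_data` default "Once upon a time" comes from this diff), and any other flags the exporter requires, such as checkpoint and params paths, are elided:

# Illustrative invocation; placeholder values, other required flags elided.
parser = build_args_parser()
args = parser.parse_args(
    [
        "--calibration_tasks", "wikitext",
        "--calibration_limit", "5",
        "--calibration_seq_length", "128",
        "--calibration_data", "Once upon a time",
        "--tokenizer_path", "tokenizer.model",
    ]
)
manager = _prepare_for_llama_export("llama2", args)  # forwards all five values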

examples/models/llama2/tokenizer/targets.bzl

Lines changed: 16 additions & 0 deletions
@@ -21,3 +21,19 @@ def define_common_targets():
             "@EXECUTORCH_CLIENTS",
         ],
     )
+
+    runtime.python_library(
+        name = "tiktoken_py",
+        srcs = [
+            "tiktoken.py",
+        ],
+        _is_external_target = True,
+        visibility = [
+            "//bento/...",
+            "//bento_kernels/...",
+            "//executorch/...",
+        ],
+        deps = [
+            "fbsource//third-party/pypi/tiktoken:tiktoken",
+        ],
+    )

extension/llm/export/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -33,5 +33,6 @@ runtime.python_library(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
         "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/tokenizer:tokenizer_py_lib",
     ],
 )

extension/llm/export/builder.py

Lines changed: 100 additions & 1 deletion
@@ -27,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -68,6 +69,11 @@ def __init__(
         example_inputs,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
+        calibration_tasks: Optional[List[str]] = None,
+        calibration_limit: Optional[int] = None,
+        calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
+        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -90,6 +96,11 @@ def __init__(
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
         self.args = args
+        self.calibration_tasks = calibration_tasks
+        self.calibration_limit = calibration_limit
+        self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
+        self.tokenizer_path = tokenizer_path

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -181,6 +192,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":

         return self

+    def pt2e_calibrate(
+        self,
+        prepared_module,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        calibration_data,
+        tokenizer_path,
+    ):
+        logging.info("Run calibration...")
+        try:
+            from executorch.examples.models.llama2.eval_llama_lib import (
+                GraphModuleEvalWrapper,
+            )
+            from executorch.examples.models.llama2.evaluate import evaluate_model
+        except ImportError:
+            raise ImportError(
+                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
+            )
+
+        tokenizer = get_tokenizer(tokenizer_path)
+
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
+
+        eval_wrapper = GraphModuleEvalWrapper(
+            model=prepared_module,
+            tokenizer=tokenizer,
+            max_seq_length=calibration_seq_length,
+            use_kv_cache=self.use_kv_cache,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+        eval_results = evaluate_model(
+            eval_wrapper,
+            calibration_tasks,
+            calibration_limit,
+        )
+
+        for task, res in eval_results["results"].items():
+            print(f"{task}: {res}")
+        logging.info("Calibration finish...")
+
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -203,8 +277,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                 self.pre_autograd_graph_module is not None
             ), "Please run capture_pre_autograd_graph first"
             m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
             # Calibrate
-            m(*self.example_inputs)
+            if (
+                self.calibration_tasks is not None
+                and self.calibration_limit is not None
+                and self.calibration_seq_length is not None
+                and self.calibration_data is not None
+                and self.tokenizer_path is not None
+            ):
+                logging.info(
+                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                )
+                self.pt2e_calibrate(
+                    prepared_module=m,
+                    calibration_tasks=self.calibration_tasks,
+                    calibration_limit=self.calibration_limit,
+                    calibration_seq_length=self.calibration_seq_length,
+                    calibration_data=self.calibration_data,
+                    tokenizer_path=self.tokenizer_path,
+                )
+            else:
+                logging.info(
+                    "No calibration provided, using dummy input to calibrate..."
+                )
+                m(*self.example_inputs)
             m = convert_pt2e(m)
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
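One behavior worth noting in `pt2e_quantize`: calibration is all-or-nothing. A hypothetical helper, not in the commit, restating that gate:

def should_run_calibration(manager) -> bool:
    # Mirrors the check in pt2e_quantize: real calibration runs only when all
    # five calibration inputs are set; otherwise the prepared module is run
    # once on the dummy example inputs.
    return all(
        value is not None
        for value in (
            manager.calibration_tasks,
            manager.calibration_limit,
            manager.calibration_seq_length,
            manager.calibration_data,
            manager.tokenizer_path,
        )
    )

In particular, passing a prompt via --calibration_data alone has no effect unless --calibration_tasks, --calibration_limit, --calibration_seq_length, and --tokenizer_path are also supplied.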

extension/llm/tokenizer/targets.bzl

Lines changed: 3 additions & 19 deletions
@@ -11,36 +11,20 @@ def define_common_targets():
         srcs = [
             "__init__.py",
             "tokenizer.py",
+            "utils.py",
         ],
         base_module = "executorch.extension.llm.tokenizer",
         visibility = [
             "//executorch/examples/...",
             "//executorch/extension/llm/tokenizer/...",
+            "//executorch/extension/llm/export/...",
             "//bento/...",
             "//bento_kernels/...",
         ],
         _is_external_target = True,
-        external_deps = [
-            "sentencepiece-py",
-        ],
-    )
-
-    runtime.python_library(
-        name = "utils",
-        srcs = [
-            "utils.py",
-        ],
-        base_module = "executorch.extension.llm.utils",
-        visibility = [
-            "//executorch/examples/...",
-            "//executorch/extension/llm/tokenizer/...",
-            "//bento/...",
-            "//bento_kernels/...",
-        ],
         deps = [
-            "//executorch/examples/models/llama2/tokenizer:tiktoken",
+            "//executorch/examples/models/llama2/tokenizer:tiktoken_py",
         ],
-        _is_external_target = True,
         external_deps = [
             "sentencepiece-py",
         ],
