
Commit 44f9526

Revert "Add proper pt2e calibration (#5095)"
This reverts commit 7122d31.
1 parent c83fd2e commit 44f9526

3 files changed: +12 additions, -183 deletions

examples/models/llama2/eval_llama_lib.py

Lines changed: 7 additions & 58 deletions
@@ -29,51 +29,6 @@
 )
 
 
-class GraphModuleEvalWrapper(EagerEvalWrapper):
-    """
-    A wrapper class for ExecuTorch py-binded integration with the
-    lm-evaluation-harness library.
-    """
-
-    def __init__(
-        self,
-        model: torch.fx.GraphModule,
-        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
-        max_seq_length: Optional[int] = None,
-        use_kv_cache: bool = False,
-        enable_dynamic_shape: bool = True,
-    ):
-        super().__init__(
-            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
-        )
-        self._model = model.to(self.device)
-        self._use_kv_cache = use_kv_cache
-        self._enable_dynamic_shape = enable_dynamic_shape
-
-    def _model_call(self, inps):
-        if self._use_kv_cache:
-            if not self._enable_dynamic_shape:
-                # graph module exported without dynamic shape won't work with a different shape.
-                # And we have to do single token prefill here.
-                result_logits = []
-                for pos in range(inps.shape[-1]):
-                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
-                    result_logits.append(logits)
-                return torch.cat(result_logits, dim=1)
-            else:
-                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
-                # Batch process the whole sequence.
-                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
-                return logits
-
-        else:
-            return self._model(inps)
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
-
-
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -193,13 +148,6 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
-        return GraphModuleEvalWrapper(
-            model=model,
-            tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-        )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
         # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -209,12 +157,13 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.model.eval().to(device="cpu")
         )
-        return EagerEvalWrapper(
-            model=model,
-            tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-        )
+
+    return EagerEvalWrapper(
+        model=model,
+        tokenizer=tokenizer,
+        max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
+    )
 
 
 def build_args_parser() -> argparse.ArgumentParser:
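For context, the deleted GraphModuleEvalWrapper fed a KV-cache graph module one position at a time whenever dynamic shapes were disabled, because a module exported with a fixed shape rejects any other input shape. Below is a minimal sketch of that single-token prefill pattern, reconstructed from the removed lines; the free-standing prefill_one_token_at_a_time helper and its arguments are illustrative, not part of the codebase.

import torch

def prefill_one_token_at_a_time(model, inps: torch.Tensor) -> torch.Tensor:
    # A graph module exported without dynamic shapes only accepts the shape it
    # was traced with, so feed a single token per call and collect the logits.
    result_logits = []
    for pos in range(inps.shape[-1]):
        pos_tensor = torch.tensor([pos], dtype=torch.int64)
        logits = model(inps[:, pos : pos + 1], pos_tensor)
        result_logits.append(logits)
    return torch.cat(result_logits, dim=1)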

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 25 deletions
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import pkg_resources
 
@@ -166,25 +166,19 @@ def build_args_parser() -> argparse.ArgumentParser:
         nargs="+",
         type=str,
         default=None,
-        help="Tasks for GPTQ calibration from lm_eval",
+        help="Tasks for GPTQ calibration",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
         default=None,
-        help="number of samples used for calibration from lm_eval",
+        help="number of samples used for calibration",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
         default=None,
-        help="Sequence length for GPTQ calibration from lm_eval",
-    )
-    parser.add_argument(
-        "--calibration_data",
-        type=str,
-        default="Once upon a time",
-        help="Calibration prompts from users",
+        help="Sequence length for GPTQ calibration",
     )
     parser.add_argument(
         "-t",
@@ -427,11 +421,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         generate_full_logits=args.generate_full_logits,
         weight_type=weight_type,
         enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
         metadata_str=args.metadata,
@@ -641,11 +630,6 @@ def _load_llama_model(
     generate_full_logits: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
     enable_dynamic_shape: bool = False,
-    calibration_tasks: Optional[List[str]] = None,
-    calibration_limit: Optional[int] = None,
-    calibration_seq_length: Optional[int] = None,
-    calibration_data: Optional[str] = None,
-    tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
     metadata_str: Optional[str] = None,
@@ -701,11 +685,6 @@ def _load_llama_model(
         use_kv_cache=use_kv_cache,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
-        calibration_tasks=calibration_tasks,
-        calibration_limit=calibration_limit,
-        calibration_seq_length=calibration_seq_length,
-        calibration_data=calibration_data,
-        tokenizer_path=tokenizer_path,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
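After the revert, build_args_parser keeps three GPTQ calibration flags, while --calibration_data is removed and the calibration_* and tokenizer_path keyword arguments no longer flow into _prepare_for_llama_export or _load_llama_model. A condensed sketch of the surviving argument definitions, assembled from the hunks above; the "--calibration_tasks" flag name is inferred from args.calibration_tasks usage elsewhere in the diff and should be treated as an assumption.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--calibration_tasks",  # flag name inferred from args.calibration_tasks usage
    nargs="+",
    type=str,
    default=None,
    help="Tasks for GPTQ calibration",
)
parser.add_argument(
    "--calibration_limit",
    type=int,
    default=None,
    help="number of samples used for calibration",
)
parser.add_argument(
    "--calibration_seq_length",
    type=int,
    default=None,
    help="Sequence length for GPTQ calibration",
)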

extension/llm/export/builder.py

Lines changed: 1 addition & 100 deletions
@@ -27,7 +27,6 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
-from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -67,11 +66,6 @@ def __init__(
         use_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
-        calibration_tasks: Optional[List[str]] = None,
-        calibration_limit: Optional[int] = None,
-        calibration_seq_length: Optional[int] = None,
-        calibration_data: Optional[str] = None,
-        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -93,11 +87,6 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
-        self.calibration_tasks = calibration_tasks
-        self.calibration_limit = calibration_limit
-        self.calibration_seq_length = calibration_seq_length
-        self.calibration_data = calibration_data
-        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -178,69 +167,6 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             )
         return self
 
-    def pt2e_calibrate(
-        self,
-        prepared_module,
-        calibration_tasks,
-        calibration_limit,
-        calibration_seq_length,
-        calibration_data,
-        tokenizer_path,
-    ):
-        logging.info("Run calibration...")
-        try:
-            from executorch.examples.models.llama2.eval_llama_lib import (
-                GraphModuleEvalWrapper,
-            )
-            from executorch.examples.models.llama2.evaluate import evaluate_model
-        except ImportError:
-            raise ImportError(
-                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
-            )
-
-        tokenizer = get_tokenizer(tokenizer_path)
-
-        def calibrate_template(
-            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
-        ):
-            # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
-            token_list = tokenizer.encode(prompts, bos=True, eos=False)
-
-            with torch.no_grad():
-                while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        torch.tensor((pos,)),
-                    )
-                    pos += 1
-                    if pos >= len(token_list):
-                        token_list.append(torch.argmax(logits[:], dim=-1).item())
-
-        calibrate_template(
-            module=prepared_module,
-            tokenizer=tokenizer,
-            prompts=calibration_data,
-            max_len=calibration_seq_length,
-        )
-
-        eval_wrapper = GraphModuleEvalWrapper(
-            model=prepared_module,
-            tokenizer=tokenizer,
-            max_seq_length=calibration_seq_length,
-            use_kv_cache=self.use_kv_cache,
-            enable_dynamic_shape=self.enable_dynamic_shape,
-        )
-        eval_results = evaluate_model(
-            eval_wrapper,
-            calibration_tasks,
-            calibration_limit,
-        )
-
-        for task, res in eval_results["results"].items():
-            print(f"{task}: {res}")
-        logging.info("Calibration finish...")
-
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
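For reference, the removed pt2e_calibrate helper drove the prepared module through a greedy decode of the user prompt so that every observer saw realistic activations, then scored the module on lm_eval tasks via GraphModuleEvalWrapper and evaluate_model. A compact sketch of that decode loop, reconstructed from the deleted lines; the standalone greedy_calibrate name is illustrative, and the tokenizer is assumed to expose encode(..., bos=, eos=) and eos_id as in the original code.

import torch

def greedy_calibrate(module, tokenizer, prompt: str, max_len: int) -> None:
    # Run the prepared (observed) module token by token; once the prompt is
    # exhausted, extend the sequence with the argmax token of the last step.
    pos = 0
    token_list = tokenizer.encode(prompt, bos=True, eos=False)
    with torch.no_grad():
        while token_list[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(
                torch.full((1, 1), token_list[pos]),
                torch.tensor((pos,)),
            )
            pos += 1
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:], dim=-1).item())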
@@ -263,33 +189,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                     self.pre_autograd_graph_module is not None
                 ), "Please run capture_pre_autograd_graph first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
-                logging.info(
-                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                )
                 # Calibrate
-                if (
-                    self.calibration_tasks is not None
-                    and self.calibration_limit is not None
-                    and self.calibration_seq_length is not None
-                    and self.calibration_data is not None
-                    and self.tokenizer_path is not None
-                ):
-                    logging.info(
-                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                    )
-                    self.pt2e_calibrate(
-                        prepared_module=m,
-                        calibration_tasks=self.calibration_tasks,
-                        calibration_limit=self.calibration_limit,
-                        calibration_seq_length=self.calibration_seq_length,
-                        calibration_data=self.calibration_data,
-                        tokenizer_path=self.tokenizer_path,
-                    )
-                else:
-                    logging.info(
-                        "No calibration provided, using dummy input to calibrate..."
-                    )
-                    m(*self.example_inputs)
+                m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
                 self.pre_autograd_graph_module = m
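With the calibration hooks reverted, pt2e_quantize calibrates solely by running the stored example inputs through the prepared module before conversion. A minimal sketch of the remaining flow under that assumption (single quantizer, manager attributes as used above); the real method also composes multiple quantizers and runs DuplicateDynamicQuantChainPass on the converted module.

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def pt2e_quantize_sketch(manager, quantizer):
    # Insert observers for the given quantizer into the captured graph module.
    m = prepare_pt2e(manager.pre_autograd_graph_module, quantizer)
    # Calibrate with the example inputs only; the lm_eval-driven path is gone.
    m(*manager.example_inputs)
    # Replace observers with quantize/dequantize ops.
    m = convert_pt2e(m)
    manager.pre_autograd_graph_module = m
    return manager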
