
Commit 85154aa

address comments
1 parent e1cbfe6 commit 85154aa

4 files changed: +53 -26 lines changed


examples/models/llama2/eval_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def gen_eval_wrapper(
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
         use_kv_cache=args.use_kv_cache,
-        dynamic_shape=(manager.dynamic_shapes != None),
+        dynamic_shape=(manager.dynamic_shapes is not None),
     )
examples/models/llama2/evaluate/eager_eval.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _model_call(self, inps):
         # graph module exported without dynamic shape won't work with a different shape.
         # And we have to do single token prefill here.
         result_logits = []
-        for pos in range(self._max_seq_length):
+        for pos in range(inps.shape[-1]):
             pos_tensor = torch.tensor([pos], dtype=torch.int64)
             logits = self._model(inps[:, pos : pos + 1], pos_tensor)
             result_logits.append(logits)
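
Note: the prefill loop now runs over the actual prompt length (inps.shape[-1]) instead of the model's max_seq_length, so it never indexes past the end of the input. A minimal, self-contained sketch of that pattern, using a hypothetical stand-in for the exported single-token KV-cache model (only the loop structure mirrors the code above):

import torch

def fake_model(token: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the exported model: takes one token and its
    # position, returns (batch, 1, vocab_size) logits.
    return torch.randn(token.shape[0], 1, 32)

inps = torch.randint(0, 32, (1, 5))  # batch of 1, prompt of 5 tokens
result_logits = []
for pos in range(inps.shape[-1]):  # iterate over the prompt, not max_seq_length
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    result_logits.append(fake_model(inps[:, pos : pos + 1], pos_tensor))
logits = torch.cat(result_logits, dim=1)  # (1, 5, 32)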

examples/models/llama2/export_llama_lib.py

Lines changed: 13 additions & 4 deletions
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 
 import pkg_resources
 
@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
         nargs="+",
         type=str,
         default=None,
-        help="Tasks for GPTQ calibration",
+        help="Tasks for GPTQ calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
         default=None,
-        help="number of samples used for calibration",
+        help="number of samples used for calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
         default=None,
-        help="Sequence length for GPTQ calibration",
+        help="Sequence length for GPTQ calibration from lm_eval",
+    )
+    parser.add_argument(
+        "--calibration_data",
+        type=str,
+        default="Once upon a time",
+        help="Calibration prompts from users",
     )
     parser.add_argument(
         "-t",
@@ -424,6 +430,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         calibration_tasks=args.calibration_tasks,
         calibration_limit=args.calibration_limit,
         calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
@@ -637,6 +644,7 @@ def _load_llama_model(
     calibration_tasks: Optional[List[str]] = None,
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
+    calibration_data: Optional[str] = None,
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
@@ -696,6 +704,7 @@ def _load_llama_model(
         calibration_tasks=calibration_tasks,
         calibration_limit=calibration_limit,
         calibration_seq_length=calibration_seq_length,
+        calibration_data=calibration_data,
         tokenizer_path=tokenizer_path,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
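
Taken together, these changes thread a user-supplied calibration prompt from the CLI down to LLMEdgeManager. A hedged sketch of how the new flag would be parsed, assuming the build_args_parser entry point shown above; the --calibration_tasks flag name and the "wikitext" task are inferred/illustrative, and a real export run typically needs the usual model flags (checkpoint, params, tokenizer, ...) as well:

from executorch.examples.models.llama2.export_llama_lib import build_args_parser

# Parse only the calibration-related flags added or updated in this commit.
parser = build_args_parser()
args = parser.parse_args(
    [
        "--calibration_tasks", "wikitext",
        "--calibration_limit", "5",
        "--calibration_seq_length", "128",
        "--calibration_data", "Once upon a time",  # new flag in this commit
    ]
)
# args.calibration_data is then forwarded to LLMEdgeManager through
# _prepare_for_llama_export(modelname, args), as shown in the hunk above.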

extension/llm/export/builder.py

Lines changed: 38 additions & 20 deletions
@@ -11,7 +11,6 @@
 import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional
-from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -28,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -70,6 +70,7 @@ def __init__(
         calibration_tasks: Optional[List[str]] = None,
         calibration_limit: Optional[int] = None,
         calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
         tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
@@ -95,6 +96,7 @@ def __init__(
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
@@ -176,41 +178,51 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
         )
         return self
 
-
     def pt2e_calibrate(
         self,
         prepared_module,
         calibration_tasks,
         calibration_limit,
         calibration_seq_length,
+        calibration_data,
         tokenizer_path,
     ):
         logging.info("Run calibration...")
         try:
-            from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
+            from executorch.examples.models.llama2.evaluate import (
+                EagerEvalWrapper,
+                evaluate_model,
+            )
         except ImportError:
             raise ImportError(
                 "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
             )
 
         tokenizer = get_tokenizer(tokenizer_path)
 
-        def calibrate_template(module: torch.fx.GraphModule, tokenizer, string: str = "Once upon a time", max_len: int = 128):
-            # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
-            token_list = [tokenizer.bos_id] + tokenizer.encode(string, bos=True, eos=False)
-
-            with torch.no_grad():
-                while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        torch.tensor((pos, )),
-                    )
-                    pos += 1
-                    if pos >= len(token_list):
-                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
 
-        calibrate_template(prepared_module, tokenizer, string="Once upon a time", max_len=calibration_seq_length)
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
 
         eval_wrapper = EagerEvalWrapper(
             model=prepared_module.to(device="cuda"),
@@ -251,20 +263,26 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             self.pre_autograd_graph_module is not None
         ), "Please run capture_pre_autograd_graph first"
         m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+        logging.info(
+            f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+        )
         # Calibrate
-        logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}, tokenizer_path: {self.tokenizer_path}")
         if (
             self.calibration_tasks is not None
             and self.calibration_limit is not None
             and self.calibration_seq_length is not None
+            and self.calibration_data is not None
             and self.tokenizer_path is not None
         ):
-            logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}")
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
             self.pt2e_calibrate(
                 prepared_module=m,
                 calibration_tasks=self.calibration_tasks,
                 calibration_limit=self.calibration_limit,
                 calibration_seq_length=self.calibration_seq_length,
+                calibration_data=self.calibration_data,
                 tokenizer_path=self.tokenizer_path,
            )
        else:
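
To make the calibration loop above concrete, here is a self-contained sketch of the same greedy single-token walk that calibrate_template performs, with toy stand-ins for the tokenizer and the prepared module (everything below is illustrative; in the real flow the tokenizer comes from get_tokenizer(tokenizer_path) and the module is the prepare_pt2e-wrapped graph):

import torch

class ToyTokenizer:
    # Hypothetical tokenizer with the same encode(bos, eos) / bos_id / eos_id surface.
    bos_id, eos_id = 1, 2
    def encode(self, s: str, bos: bool, eos: bool):
        ids = [self.bos_id] if bos else []
        ids += [3 + (ord(c) % 29) for c in s]
        return ids + ([self.eos_id] if eos else [])

class ToyModule(torch.nn.Module):
    # Hypothetical stand-in for the prepared module: one token + position in,
    # (batch, 1, vocab_size) logits out.
    def forward(self, token, pos):
        return torch.randn(1, 1, 32)

def calibrate_template(module, tokenizer, prompts: str, max_len: int):
    # Feed the prompt token by token, then greedily extend until EOS or max_len,
    # so the observers inserted by prepare_pt2e see realistic activations.
    pos = torch.tensor(0, dtype=torch.int64)
    token_list = tokenizer.encode(prompts, bos=True, eos=False)
    with torch.no_grad():
        while token_list[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(torch.full((1, 1), token_list[pos]), torch.tensor((pos,)))
            pos += 1
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:], dim=-1).item())

calibrate_template(ToyModule(), ToyTokenizer(), prompts="Once upon a time", max_len=16)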
