
Commit d4d7cfa

add graph module eval wrapper
1 parent 21d3974 commit d4d7cfa

3 files changed: +67 -29 lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 58 additions & 7 deletions
@@ -29,6 +29,51 @@
 )
 
 
+class GraphModuleEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for ExecuTorch py-binded integration with the
+    lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.fx.GraphModule,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+        enable_dynamic_shape: bool = True,
+    ):
+        super().__init__(
+            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+        )
+        self._model = model.to(self.device)
+        self._use_kv_cache = use_kv_cache
+        self._enable_dynamic_shape = enable_dynamic_shape
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            if not self._enable_dynamic_shape:
+                # A graph module exported without dynamic shapes cannot run on a
+                # different input shape, so we have to do single-token prefill here.
+                result_logits = []
+                for pos in range(inps.shape[-1]):
+                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    result_logits.append(logits)
+                return torch.cat(result_logits, dim=1)
+            else:
+                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+                # Batch process the whole sequence.
+                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                return logits
+
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise NotImplementedError("unimplemented")
+
+
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
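
To make the two prefill strategies in _model_call concrete, here is a self-contained sketch that exercises both paths against a toy module. The ToyModel class, vocabulary size, and shapes are illustrative assumptions, not part of this commit; a real exported llama2 graph module would consume start_pos to index into its KV cache.

import torch

class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for an exported llama2 graph module:
    takes (tokens, start_pos) and returns per-token logits."""

    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, tokens, start_pos):
        # A real model would read/write its KV cache at start_pos;
        # this toy only embeds tokens, so start_pos is unused.
        return self.emb(tokens)

model = ToyModel()
inps = torch.randint(0, 32, (1, 8))  # (batch, seq_len)

# Static-shape path: one forward call per position (single-token prefill).
single = torch.cat(
    [
        model(inps[:, pos : pos + 1], torch.tensor([pos], dtype=torch.int64))
        for pos in range(inps.shape[-1])
    ],
    dim=1,
)

# Dynamic-shape path: one batched forward over the whole sequence.
batched = model(inps, torch.tensor([0], dtype=torch.int64))

assert torch.equal(single, batched)  # identical for this stateless toy

The two paths agree exactly here only because the toy is stateless; for a real KV-cache model they are different call patterns over the same sequence, which is why the wrapper needs enable_dynamic_shape to choose between them.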
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
         # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -157,13 +209,12 @@ def gen_eval_wrapper(
         if torch.cuda.is_available()
         else manager.model.eval().to(device="cpu")
     )
-    return EagerEvalWrapper(
-        model=model,
-        tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length,
-        use_kv_cache=args.use_kv_cache,
-        dynamic_shape=(manager.dynamic_shapes is not None),
-    )
+    return EagerEvalWrapper(
+        model=model,
+        tokenizer=tokenizer,
+        max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
+    )
 
 
 def build_args_parser() -> argparse.ArgumentParser:
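
Why enable_dynamic_shape is threaded through from args: an exported graph module is shape-specialized, and unless a dimension is marked dynamic at export time, runtime guards reject inputs of any other length. A self-contained sketch of that behavior with torch.export (the Toy module and dimension bounds are illustrative, and the exact guard error varies across PyTorch versions):

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, tokens):
        return tokens * 2

example = (torch.zeros(1, 8, dtype=torch.int64),)

# Exported without dynamic shapes: specialized to seq_len == 8.
ep_static = export(Toy(), example)

# Exported with a dynamic sequence dimension: accepts lengths in [2, 32].
ep_dynamic = export(
    Toy(), example, dynamic_shapes={"tokens": {1: Dim("seq", min=2, max=32)}}
)

ep_dynamic.module()(torch.zeros(1, 5, dtype=torch.int64))  # fine
try:
    ep_static.module()(torch.zeros(1, 5, dtype=torch.int64))
except Exception as err:  # guard failure: wrong shape for a static export
    print("static export rejected seq_len=5:", err)

This is exactly the situation the single-token prefill path works around: a module exported with a fixed sequence length can still be evaluated by feeding it one token at a time.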

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 4 additions & 17 deletions
@@ -33,7 +33,6 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
         use_kv_cache: bool = False,
-        dynamic_shape: bool = True,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
@@ -42,7 +41,6 @@ def __init__(
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
         self._use_kv_cache = use_kv_cache
-        self._dynamic_shape = dynamic_shape
 
     @property
     def eot_token_id(self):
@@ -79,21 +77,10 @@ def tok_decode(self, tokens):
 
     def _model_call(self, inps):
         if self._use_kv_cache:
-            if not self._dynamic_shape:
-                # graph module exported without dynamic shape won't work with a different shape.
-                # And we have to do single token prefill here.
-                result_logits = []
-                for pos in range(inps.shape[-1]):
-                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
-                    result_logits.append(logits)
-                return torch.cat(result_logits, dim=1)
-            else:
-                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
-                # Batch process the whole sequence.
-                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
-                return logits
-
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+            # Batch process the whole sequence.
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
         else:
             return self._model(inps)
 
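For orientation, lm-evaluation-harness drives both wrappers through _model_call: a (batch, seq_len) tensor of token ids goes in, and per-position logits come out. A minimal sanity check of that contract (the helper name, shapes, and vocabulary size are illustrative assumptions):

import torch

def check_model_call_contract(wrapper, vocab_size: int = 32000) -> None:
    # Token ids in, one row of logits per input position out.
    inps = torch.randint(0, vocab_size, (1, 16))  # (batch, seq_len)
    logits = wrapper._model_call(inps)
    assert logits.dim() == 3                 # (batch, seq_len, vocab)
    assert logits.shape[:2] == inps.shape    # aligned with input positions

With the dynamic_shape branch removed, EagerEvalWrapper always prefills the whole (truncated) sequence in one batched call when the KV cache is enabled; the per-token fallback now lives only in GraphModuleEvalWrapper.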

extension/llm/export/builder.py

Lines changed: 5 additions & 5 deletions
@@ -189,10 +189,10 @@ def pt2e_calibrate(
         ):
             logging.info("Run calibration...")
             try:
-                from executorch.examples.models.llama2.evaluate import (
-                    EagerEvalWrapper,
-                    evaluate_model,
+                from executorch.examples.models.llama2.eval_llama_lib import (
+                    GraphModuleEvalWrapper,
                 )
+                from executorch.examples.models.llama2.evaluate import evaluate_model
             except ImportError:
                 raise ImportError(
                     "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
@@ -224,12 +224,12 @@ def calibrate_template(
                 max_len=calibration_seq_length,
             )
 
-            eval_wrapper = EagerEvalWrapper(
+            eval_wrapper = GraphModuleEvalWrapper(
                 model=prepared_module,
                 tokenizer=tokenizer,
                 max_seq_length=calibration_seq_length,
                 use_kv_cache=self.use_kv_cache,
-                dynamic_shape=self.enable_dynamic_shape,
+                enable_dynamic_shape=self.enable_dynamic_shape,
             )
             eval_results = evaluate_model(
                 eval_wrapper,
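
Putting the builder changes together: calibration now wraps the pt2e-prepared module, which is a torch.fx.GraphModule, in the new wrapper so the pt2e observers collect statistics while lm-evaluation-harness drives the forward passes. A condensed sketch under stated assumptions (prepared_module, tokenizer, and the calibration settings come from the surrounding pt2e_calibrate; the trailing evaluate_model arguments are abbreviated in the diff and assumed here to be a task list and a limit):

from executorch.examples.models.llama2.eval_llama_lib import GraphModuleEvalWrapper
from executorch.examples.models.llama2.evaluate import evaluate_model

def run_calibration(prepared_module, tokenizer, calibration_seq_length,
                    use_kv_cache, enable_dynamic_shape, tasks, limit):
    # Wrap the prepared graph module so the harness can run it; each
    # forward pass feeds the pt2e observers with calibration statistics.
    eval_wrapper = GraphModuleEvalWrapper(
        model=prepared_module,
        tokenizer=tokenizer,
        max_seq_length=calibration_seq_length,
        use_kv_cache=use_kv_cache,
        enable_dynamic_shape=enable_dynamic_shape,
    )
    return evaluate_model(eval_wrapper, tasks, limit)  # trailing args assumed

Renaming the keyword from dynamic_shape to enable_dynamic_shape keeps the calibration wrapper aligned with the exporter flag of the same name.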
