Skip to content

Commit ca9530f

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Add placeholder wrappers for ET evaluation: Eager and Runtime (#2821)
Summary: Pull Request resolved: #2821 ET Evaluation can reuse most of the work from evaluating AOT pre-lowered models. This diff introduces 2 skeleton implementations of EvaluationWrappers: * one for eager * one for runtime. Notably, these wrappers are used when a .pte file is provided (with the latter requiring a tokenizer binary) Reviewed By: larryliu0820 Differential Revision: D55674251 fbshipit-source-id: f972a1728b8c3d5cead2fcf2ad3196f6b3254dd6
1 parent 65f3e18 commit ca9530f

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,58 @@ def _model_generate(self, context, max_length, eos_token_id):
8484
raise Exception("unimplemented")
8585

8686

87+
class ETEagerEvalWrapper(GPTFastEvalWrapper):
88+
"""
89+
A wrapper class for ExecuTorch Eager integration with the
90+
lm-evaluation-harness library.
91+
"""
92+
93+
def __init__(
94+
self,
95+
model: str,
96+
tokenizer: SentencePieceProcessor,
97+
max_seq_length: Optional[int] = None,
98+
):
99+
super().__init__(None, tokenizer, max_seq_length)
100+
self._model = model
101+
102+
def _model_call(self, inps):
103+
# Given inps (tokens), return the logits from a single
104+
# forward call
105+
106+
# Example:
107+
# inps: Tensor of shape (1, N)
108+
# logits: Tensor of shape (1, N, 32000)
109+
pass
110+
111+
112+
class ETRunnerEvalWrapper(GPTFastEvalWrapper):
113+
"""
114+
A wrapper class for ExecuTorch Runtime integration with the
115+
lm-evaluation-harness library.
116+
"""
117+
118+
def __init__(
119+
self,
120+
model: str,
121+
tokenizer: SentencePieceProcessor,
122+
tokenizer_bin: str,
123+
max_seq_length: Optional[int] = None,
124+
):
125+
super().__init__(None, tokenizer, max_seq_length)
126+
self._model = model
127+
self._tokenizer_bin = tokenizer_bin
128+
129+
def _model_call(self, inps):
130+
# Given inps (tokens), return the logits from a single
131+
# forward call
132+
133+
# Example:
134+
# inps: Tensor of shape (1, N)
135+
# logits: Tensor of shape (1, N, 32000)
136+
pass
137+
138+
87139
@torch.no_grad()
88140
def eval(
89141
eval_wrapper: LM,
@@ -131,6 +183,24 @@ def gen_eval_wrapper(
131183
"""
132184
tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
133185

186+
# ExecuTorch Binary Evaluation
187+
if (model := args.pte) is not None:
188+
if (tokenizer_bin := args.tokenizer_bin) is not None:
189+
# ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
190+
return ETRunnerEvalWrapper(
191+
model=model,
192+
tokenizer=tokenizer,
193+
tokenizer_bin=tokenizer_bin,
194+
max_seq_length=args.max_seq_length,
195+
)
196+
197+
# ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated eagerly
198+
return ETEagerEvalWrapper(
199+
model=model,
200+
tokenizer=tokenizer,
201+
max_seq_length=args.max_seq_length,
202+
)
203+
134204
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
135205
manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
136206
model = (
@@ -161,6 +231,21 @@ def build_args_parser() -> argparse.ArgumentParser:
161231
"--limit", type=int, default=5, help="number of samples to evalulate"
162232
)
163233

234+
# Add additional args specific to eval via an ET Runner
235+
# Note: For initial integration, the tokenizer.model is also required
236+
parser.add_argument(
237+
"--pte",
238+
type=str,
239+
default=None,
240+
help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
241+
)
242+
parser.add_argument(
243+
"--tokenizer_bin",
244+
type=str,
245+
default=None,
246+
help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
247+
)
248+
164249
return parser
165250

166251

0 commit comments

Comments
 (0)