
Commit b60fa71

Lunwen He authored and facebook-github-bot committed
buckify eval_llama (#5437)
Summary:
Pull Request resolved: #5437

This PR buckifies `eval_llama`, which is useful when we need to run eval via Buck.

Reviewed By: mergennachin

Differential Revision: D62897016

fbshipit-source-id: 59cc64eaa3b29f707b9aa7d3ac2568a16c2743c9
1 parent f0662bb · commit b60fa71
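As the summary notes, the point of buckifying is to make eval runnable through Buck. A hypothetical invocation of the new binary target is sketched below; the target path follows the TARGETS diff in this commit, while the model, tokenizer, and task flags come from the shared args parser and are intentionally not spelled out here.

# Sketch only: the target path comes from examples/models/llama2/TARGETS below;
# pass --help (or the parser's model/tokenizer/task flags) after the "--" separator.
buck2 run //executorch/examples/models/llama2:eval_llama -- --help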

File tree

5 files changed: +64 −19 lines


examples/models/llama2/TARGETS

Lines changed: 40 additions & 0 deletions
@@ -104,5 +104,45 @@ runtime.python_library(
         "//executorch/util:python_profiler",
         "fbsource//third-party/pypi/coremltools:coremltools",
         "fbsource//third-party/pypi/sentencepiece:sentencepiece",
+        "//pytorch/ao:torchao",
+    ],
+)
+
+runtime.python_binary(
+    name = "eval_llama",
+    main_function = "executorch.examples.models.llama2.eval_llama.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        ":eval_library",
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_library(
+    name = "eval_library",
+    srcs = [
+        "eval_llama.py",
+        "eval_llama_lib.py",
+        "evaluate/eager_eval.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2",
+    visibility = [
+        "//bento/...",
+        "//bento_kernels/...",
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/lm-eval:lm-eval",
+        "fbsource//third-party/pypi/tiktoken:tiktoken",
+        ":export_library",
+        "//executorch/examples/models/llama2/tokenizer:tiktoken_py",
+        "//executorch/extension/llm/export:export_lib",
+        "//executorch/extension/llm/tokenizer:tokenizer_py_lib",
+        "//executorch/extension/pybindings:portable_lib",
     ],
 )

examples/models/llama2/eval_llama.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
-    eval_llama(modelname, args)
+    eval_llama(modelname, args)  # pyre-ignore


 if __name__ == "__main__":

examples/models/llama2/eval_llama_lib.py

Lines changed: 16 additions & 15 deletions
@@ -10,19 +10,20 @@
 from typing import Optional, Union

 import torch
-from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken

-from executorch.extension.llm.export import LLMEdgeManager
+from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.api.model import LM

+from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model
+
 from .export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
@@ -91,7 +92,7 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model  # Expects model to be path to a .pte file

         from executorch.extension.pybindings.portable_lib import _load_for_executorch
@@ -106,7 +107,7 @@ def __init__(
         from executorch.kernels import quantized  # noqa

         self._et_model = _load_for_executorch(self._model)
-        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
@@ -140,7 +141,7 @@ def __init__(
         tokenizer_bin: str,
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model
         self._tokenizer_bin = tokenizer_bin

@@ -165,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore

     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:
-        if (tokenizer_bin := args.tokenizer_bin) is not None:
+    if (model := args.pte) is not None:  # pyre-ignore
+        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
-               max_seq_length=args.max_seq_length,
+               max_seq_length=args.max_seq_length,  # pyre-ignore
            )

     # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -194,16 +195,16 @@ def gen_eval_wrapper(
     if len(quantizers) != 0:
         manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
         model = (
-            manager.pre_autograd_graph_module.to(device="cuda")
+            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -221,7 +222,7 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:
+        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
             torch.save(model, args.output_eager_checkpoint_file)

         return EagerEvalWrapper(
@@ -282,8 +283,8 @@ def eval_llama(
     # Evaluate the model
     eval_results = evaluate_model(
         eval_wrapper,
-        args.tasks,
-        args.limit,
+        args.tasks,  # pyre-ignore
+        args.limit,  # pyre-ignore
     )

     for task, res in eval_results["results"].items():
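For context on the `run_method` call that now carries a pyre-ignore: the pybind path loads a `.pte` file and queries metadata baked into the exported program. A minimal sketch mirroring `ETPybindEvalWrapper` above, with the model path as a placeholder:

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the exported ExecuTorch program (the path is a placeholder).
et_model = _load_for_executorch("/tmp/llama2.pte")

# run_method returns a list of outputs, which is why the wrapper indexes [0]
# to read the boolean use_kv_cache flag stored in the program.
use_kv_cache = et_model.run_method("use_kv_cache")[0]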

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 4 additions & 2 deletions
@@ -62,7 +62,7 @@ def batch_size(self):
     def device(self):
         return self._device

-    def tok_encode(self, string: str, **kwargs):
+    def tok_encode(self, string: str, **kwargs):  # pyre-ignore
         tokens = self._tokenizer.encode(string, bos=True, eos=False)
         encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
         # encoded is a pytorch tensor, but some internal logic in the
@@ -111,7 +111,9 @@ def evaluate_model(

     if "hendrycks_test" in tasks:
         tasks.remove("hendrycks_test")
-        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+        tasks += list(
+            lm_eval.tasks.hendrycks_test.create_all_tasks().keys()  # pyre-ignore
+        )
     task_dict = get_task_dict(tasks)

     eval_results = evaluate(
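The `evaluate_model` helper touched here is what `eval_llama` ultimately calls once a wrapper exists. A minimal, hypothetical usage sketch (the task list and limit are placeholders, and `eval_wrapper` would come from `gen_eval_wrapper` in eval_llama_lib.py):

from executorch.examples.models.llama2.evaluate import evaluate_model

# eval_wrapper is any lm-eval LM wrapper, e.g. the EagerEvalWrapper built by
# gen_eval_wrapper; tasks and limit follow the positional call in eval_llama.
eval_results = evaluate_model(eval_wrapper, ["wikitext"], None)
for task, res in eval_results["results"].items():
    print(f"{task}: {res}")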

extension/llm/export/builder.py

Lines changed: 3 additions & 1 deletion
@@ -209,7 +209,9 @@ def pt2e_calibrate(
             from executorch.examples.models.llama2.eval_llama_lib import (
                 GraphModuleEvalWrapper,
             )
-            from executorch.examples.models.llama2.evaluate import evaluate_model
+            from executorch.examples.models.llama2.evaluate import (  # pyre-ignore[21]
+                evaluate_model,
+            )
         except ImportError:
             raise ImportError(
                 "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
