Skip to content

Commit ef21787

Browse files
Jack-Khuu authored and facebook-github-bot committed
Move LM wrapper generation into helper (#2609)
Summary: Pull Request resolved: #2609 Part of a refactor of eval_llama_lib to support integration with runtime evaluation and removing model quantization from the eval flow. Specifically this diff just moves the Wrapper generation logic into a helper bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: iseeyuan Differential Revision: D55273853 fbshipit-source-id: 98e1e6108454f39d90aa79cc09024d732254aa64
1 parent 542ef50 commit ef21787

File tree

1 file changed

+35
-25
lines changed

1 file changed

+35
-25
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import lm_eval
1212

1313
import torch
14+
from lm_eval.api.model import LM
1415
from lm_eval.evaluator import evaluate
1516
from lm_eval.models.huggingface import HFLM as eval_wrapper
1617
from lm_eval.tasks import get_task_dict
@@ -26,7 +27,7 @@
2627

2728
class GPTFastEvalWrapper(eval_wrapper):
2829
"""
29-
A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
30+
A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
3031
"""
3132

3233
def __init__(
@@ -85,21 +86,17 @@ def _model_generate(self, context, max_length, eos_token_id):
8586

8687
@torch.no_grad()
8788
def eval(
88-
model: nn.Module,
89-
tokenizer,
89+
eval_wrapper: LM,
9090
tasks: Optional[list] = None,
9191
limit: Optional[int] = None,
92-
max_seq_length: Optional[int] = None,
9392
) -> dict:
9493
"""
9594
Evaluates a language model on a specified task using the lm-evaluation-harness library.
9695
9796
Args:
98-
model (nn.Module): The pre-trained language model to evaluate.
99-
tokenizer: The tokenizer to use for encoding/decoding text.
97+
eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
10098
task (str): The name of the evaluation task to perform.
10199
limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
102-
max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
103100
104101
Returns:
105102
eval_results (dict): A dictionary of evaluation results for the specified task(s).
@@ -108,25 +105,46 @@ def eval(
108105
if tasks is None:
109106
tasks = ["wikitext"]
110107

111-
model_eval_wrapper = GPTFastEvalWrapper(
112-
model,
113-
tokenizer,
114-
max_seq_length,
115-
)
116-
117108
if "hendrycks_test" in tasks:
118109
tasks.remove("hendrycks_test")
119110
tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
120111
task_dict = get_task_dict(tasks)
121112

122113
eval_results = evaluate(
123-
model_eval_wrapper,
114+
eval_wrapper,
124115
task_dict,
125116
limit=limit,
126117
)
127118
return eval_results
128119

129120

121+
def gen_eval_wrapper(
122+
model_name: str,
123+
args: argparse.ArgumentParser,
124+
) -> LM:
125+
"""
126+
Generates a wrapper interface around the provided model and tokenizer for
127+
the lm-evaluation-harness library.
128+
129+
Returns:
130+
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
131+
"""
132+
tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
133+
134+
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
135+
manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
136+
model = (
137+
manager.model.eval().to(device="cuda")
138+
if torch.cuda.is_available()
139+
else manager.model.to(device="cpu")
140+
)
141+
return GPTFastEvalWrapper(
142+
model=model,
143+
tokenizer=tokenizer,
144+
max_seq_length=args.max_seq_length,
145+
)
146+
147+
130148
def build_args_parser() -> argparse.ArgumentParser:
131149
# Start with arg parser from export_llama_lib
132150
parser = _build_args_parser()
@@ -150,22 +168,14 @@ def eval_llama(
150168
model_name: str,
151169
args: argparse.ArgumentParser,
152170
) -> None:
153-
# Get a pre-lowering/to_edge LlamaEdgeManager instance
154-
manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
155-
tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
171+
# Generate the eval wrapper
172+
eval_wrapper = gen_eval_wrapper(model_name, args)
156173

157174
# Evaluate the model
158-
model = (
159-
manager.model.eval().to(device="cuda")
160-
if torch.cuda.is_available()
161-
else manager.model.to(device="cpu")
162-
)
163175
eval_results = eval(
164-
model,
165-
tokenizer,
176+
eval_wrapper,
166177
args.tasks,
167178
args.limit,
168-
args.max_seq_length,
169179
)
170180

171181
for task, res in eval_results["results"].items():

0 commit comments

Comments (0)