
Commit 49b23bb

add option to run mmlu with 5 shots
[ghstack-poisoned]
1 parent df5b2ab commit 49b23bb

File tree

3 files changed (+20, -48 lines):
  examples/models/llama2/eval_llama_lib.py
  examples/models/llama2/evaluate/__init__.py
  examples/models/llama2/evaluate/eager_eval.py

examples/models/llama2/eval_llama_lib.py

Lines changed: 18 additions & 8 deletions
@@ -21,8 +21,9 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.api.model import LM
+from lm_eval.evaluator import simple_evaluate
 
-from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model
+from .evaluate.eager_eval import EagerEvalWrapper
 
 from .export_llama_lib import (
     _prepare_for_llama_export,
@@ -246,9 +247,16 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
     )
     parser.add_argument(
-        "--limit", type=int, default=5, help="number of samples to evalulate"
+        "--limit", type=int, default=None, help="number of samples to evalulate"
+    )
+    parser.add_argument(
+        "-f",
+        "--num_fewshot",
+        type=int,
+        default=None,
+        metavar="N",
+        help="Number of examples in few-shot context",
     )
-
     # Add additional args specific to eval via an ET Runner
     # Note: For initial integration, the tokenizer.model is also required
     parser.add_argument(
@@ -281,11 +289,13 @@ def eval_llama(
     eval_wrapper = gen_eval_wrapper(model_name, args)
 
     # Evaluate the model
-    eval_results = evaluate_model(
-        eval_wrapper,
-        args.tasks,  # pyre-ignore
-        args.limit,  # pyre-ignore
-    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=args.tasks,
+            num_fewshot=args.num_fewshot,
+            limit=args.limit,
+        )
 
     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
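A minimal sketch of exercising the new flag end to end is shown below. It assumes the examples package is importable as executorch.examples.models.llama2.eval_llama_lib, that the parser's remaining model/checkpoint/tokenizer arguments are satisfied by their defaults or supplied separately, and that the installed lm-eval release exposes an "mmlu" task group; the "llama2" model name and the argument values are placeholders, not part of this commit.

# Hypothetical driver script: parse the new --num_fewshot flag and run eval_llama,
# which now forwards num_fewshot and limit to lm-eval's simple_evaluate.
from executorch.examples.models.llama2.eval_llama_lib import (
    build_args_parser,
    eval_llama,
)

parser = build_args_parser()
# 5-shot MMLU, capped at 100 samples; the task name and limit are illustrative.
args = parser.parse_args(
    ["--tasks", "mmlu", "--num_fewshot", "5", "--limit", "100"]
)
eval_llama("llama2", args)  # prints per-task results returned by simple_evaluate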

examples/models/llama2/evaluate/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -4,9 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .eager_eval import EagerEvalWrapper, evaluate_model
+from .eager_eval import EagerEvalWrapper
 
 __all__ = [
-    "evaluate_model",
     "EagerEvalWrapper",
 ]

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 1 addition & 38 deletions
@@ -15,9 +15,8 @@
 )
 
 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
+from lm_eval.evaluator import simple_evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict
 
 from torch import nn
 
@@ -79,39 +78,3 @@ def _model_call(self, inps):
 
     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
-
-
-@torch.no_grad()
-def evaluate_model(
-    eval_wrapper: LM,
-    tasks: Optional[list] = None,
-    limit: Optional[int] = None,
-) -> dict:
-    """
-    Evaluates a language model on a specified task using the lm-evaluation-harness library.
-
-    Args:
-        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
-        tasks: Optional[list]: The names of the evaluation tasks to perform.
-        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
-
-    Returns:
-        eval_results (dict): A dictionary of evaluation results for the specified task(s).
-    """
-
-    if tasks is None:
-        tasks = ["wikitext"]
-
-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(
-            lm_eval.tasks.hendrycks_test.create_all_tasks().keys()  # pyre-ignore
-        )
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
-        limit=limit,
-    )
-    return eval_results
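With evaluate_model removed, the eager path reduces to a direct simple_evaluate call. A minimal sketch, assuming `wrapper` is an already-constructed EagerEvalWrapper (or any lm-eval LM instance) and that the installed lm_eval version accepts these keyword arguments as used in the diff above:

import torch
from lm_eval.evaluator import simple_evaluate

# simple_evaluate resolves task names and builds the task dict internally,
# which is why the get_task_dict and hendrycks_test expansion code was removed.
with torch.no_grad():
    results = simple_evaluate(
        model=wrapper,       # assumed: an EagerEvalWrapper built elsewhere
        tasks=["wikitext"],  # task list; swap in an MMLU task group for 5-shot MMLU
        num_fewshot=5,       # few-shot context size, matching the new --num_fewshot flag
        limit=None,          # None evaluates all available samples
    )

for task, res in results["results"].items():
    print(f"{task}: {res}")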
