
Commit 194d677

Factor out eager eval from eval_llama_lib
Would like to re-use EagerEvalWrapper and the eval function for quantization calibration.

Differential Revision: [D57881028](https://our.internmc.facebook.com/intern/diff/D57881028/)
ghstack-source-id: 227994056
Pull Request resolved: #3756
1 parent 79e9b79 commit 194d677
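
The stated motivation is reusing the factored-out pieces from the quantization-calibration flow. As a hedged sketch of that reuse — the calibration entry point itself is not part of this commit, and `calibrate_with_eval`, `model`, and `tokenizer` are hypothetical stand-ins for objects the quantization flow already has — the two exported helpers could be driven like this:

```python
import torch

# Same import path that this commit's diff adds to eval_llama_lib.py.
from executorch.examples.models.llama2.source_transformation.quantize import (
    EagerEvalWrapper,
    evaluate_model,
)


def calibrate_with_eval(model: torch.nn.Module, tokenizer) -> dict:
    # Wrap the eager model so lm-evaluation-harness can drive it.
    wrapper = EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=2048,  # matches the wrapper's default when None is passed
        use_kv_cache=False,
    )
    # A small, capped run is enough to push calibration data through the model.
    return evaluate_model(wrapper, tasks=["wikitext"], limit=10)
```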

File tree: 3 files changed (+141, -110 lines)


examples/models/llama2/eval_llama_lib.py

Lines changed: 5 additions & 110 deletions
```diff
@@ -9,22 +9,20 @@

 from typing import Optional, Union

-import lm_eval
 import torch
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
+from executorch.examples.models.llama2.source_transformation.quantize import (
+    EagerEvalWrapper,
+    evaluate_model,
+)
 from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.models.llama2.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )

 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
-from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict
-
-from torch import nn

 from .builder import LlamaEdgeManager
 from .export_llama_lib import (
@@ -33,75 +31,6 @@
 )


-class EagerEvalWrapper(eval_wrapper):
-    """
-    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
-        max_seq_length: Optional[int] = None,
-        use_kv_cache: bool = False,
-    ):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        super().__init__(device=device)
-        self._model = model
-        self._tokenizer = tokenizer
-        self._device = torch.device(device)
-        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
-        self._use_kv_cache = use_kv_cache
-
-    @property
-    def eot_token_id(self):
-        return self._tokenizer.eos_id
-
-    @property
-    def max_length(self):
-        return self._max_seq_length
-
-    @property
-    def max_gen_toks(self):
-        return 50
-
-    @property
-    def batch_size(self):
-        return 1
-
-    @property
-    def device(self):
-        return self._device
-
-    def tok_encode(self, string: str, **kwargs):
-        tokens = self._tokenizer.encode(string, bos=True, eos=False)
-        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
-        # encoded is a pytorch tensor, but some internal logic in the
-        # eval harness expects it to be a list instead
-        # TODO: verify this for multi-batch as well
-        encoded = encoded.tolist()
-        return encoded
-
-    def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
-        return decoded
-
-    def _model_call(self, inps):
-        if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
-            # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
-            return logits
-        else:
-            return self._model(inps)
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
-
-
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -165,40 +94,6 @@ def _model_call(self, inps):
         pass


-@torch.no_grad()
-def eval(
-    eval_wrapper: LM,
-    tasks: Optional[list] = None,
-    limit: Optional[int] = None,
-) -> dict:
-    """
-    Evaluates a language model on a specified task using the lm-evaluation-harness library.
-
-    Args:
-        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
-        task (str): The name of the evaluation task to perform.
-        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
-
-    Returns:
-        eval_results (dict): A dictionary of evaluation results for the specified task(s).
-    """
-
-    if tasks is None:
-        tasks = ["wikitext"]
-
-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
-        limit=limit,
-    )
-    return eval_results
-
-
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
@@ -307,7 +202,7 @@ def eval_llama(
     eval_wrapper = gen_eval_wrapper(model_name, args)

     # Evaluate the model
-    eval_results = eval(
+    eval_results = evaluate_model(
         eval_wrapper,
         args.tasks,
         args.limit,
```
__init__.py (new file)

Lines changed: 12 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .eval import EagerEvalWrapper, evaluate_model

__all__ = [
    "evaluate_model",
    "EagerEvalWrapper",
]
```
eval.py (new file)

Lines changed: 124 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Union

import lm_eval
import torch
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.examples.models.llama2.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)

from lm_eval.api.model import LM
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM as eval_wrapper
from lm_eval.tasks import get_task_dict

from torch import nn


class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(device=device)
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        return self._tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        tokens = self._tokenizer.encode(string, bos=True, eos=False)
        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
        # encoded is a pytorch tensor, but some internal logic in the
        # eval harness expects it to be a list instead
        # TODO: verify this for multi-batch as well
        encoded = encoded.tolist()
        return encoded

    def tok_decode(self, tokens):
        decoded = self._tokenizer.decode(tokens)
        return decoded

    def _model_call(self, inps):
        if self._use_kv_cache:
            pos_tensor = torch.arange(
                self._max_seq_length, dtype=torch.int64, device=self.device
            )

            # Batch process the whole sequence.
            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
            return logits
        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")


@torch.no_grad()
def evaluate_model(
    eval_wrapper: LM,
    tasks: Optional[list] = None,
    limit: Optional[int] = None,
) -> dict:
    """
    Evaluates a language model on the specified tasks using the lm-evaluation-harness library.

    Args:
        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
        tasks (Optional[list]): The names of the evaluation tasks to perform (defaults to ["wikitext"]).
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).

    Returns:
        eval_results (dict): A dictionary of evaluation results for the specified task(s).
    """

    if tasks is None:
        tasks = ["wikitext"]

    if "hendrycks_test" in tasks:
        tasks.remove("hendrycks_test")
        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
    task_dict = get_task_dict(tasks)

    eval_results = evaluate(
        eval_wrapper,
        task_dict,
        limit=limit,
    )
    return eval_results
```
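
With the move complete, the harness integration can be imported without pulling in `eval_llama_lib`. A minimal end-to-end sketch under stated assumptions: `run_wikitext_eval` is a hypothetical driver, the SentencePiece tokenizer constructor is assumed to take the model path, and the `results["results"]` layout is the dict returned by `lm_eval.evaluator.evaluate`:

```python
import torch

# Same import path that this commit's diff uses in eval_llama_lib.py.
from executorch.examples.models.llama2.source_transformation.quantize import (
    EagerEvalWrapper,
    evaluate_model,
)
from executorch.examples.models.llama2.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)


def run_wikitext_eval(model: torch.nn.Module, tokenizer_path: str) -> None:
    # Load the tokenizer and wrap the eager model for the harness.
    tokenizer = SentencePieceTokenizer(tokenizer_path)
    wrapper = EagerEvalWrapper(model=model, tokenizer=tokenizer)
    # Cap the number of samples to keep the smoke run cheap.
    results = evaluate_model(wrapper, tasks=["wikitext"], limit=5)
    for task, metrics in results["results"].items():
        print(task, metrics)
```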
