)


+class GraphModuleEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for ExecuTorch py-binded integration with the
+    lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.fx.GraphModule,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+        enable_dynamic_shape: bool = True,
+    ):
+        super().__init__(
+            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+        )
+        self._model = model.to(self.device)
+        self._use_kv_cache = use_kv_cache
+        self._enable_dynamic_shape = enable_dynamic_shape
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            if not self._enable_dynamic_shape:
+                # A graph module exported without dynamic shapes cannot run with a
+                # different input shape, so we have to prefill one token at a time.
+                result_logits = []
+                for pos in range(inps.shape[-1]):
+                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    result_logits.append(logits)
+                return torch.cat(result_logits, dim=1)
+            else:
+                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+                # Batch process the whole sequence.
+                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                return logits
+
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
        # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -157,13 +209,12 @@ def gen_eval_wrapper(
            if torch.cuda.is_available()
            else manager.model.eval().to(device="cpu")
        )
-        return EagerEvalWrapper(
-            model=model,
-            tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-            dynamic_shape=(manager.dynamic_shapes is not None),
-        )
+        return EagerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+        )


def build_args_parser() -> argparse.ArgumentParser:
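
For reference, the single-token prefill branch that `_model_call` takes when the graph module is exported without dynamic shapes can be reproduced in isolation. The sketch below is illustrative only and not part of this change: `TinyLM`, the vocabulary size, and the input shapes are made-up stand-ins for an exported kv-cache model that accepts `(tokens, input_pos)`.

```python
import torch


class TinyLM(torch.nn.Module):
    """Toy stand-in for a kv-cache model: takes a single token and its position,
    and returns logits of shape (batch, 1, vocab_size)."""

    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, 16)
        self.head = torch.nn.Linear(16, vocab_size)

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, 1) token ids; input_pos: (1,) position index.
        # A real kv-cache model would use input_pos to update its caches;
        # this toy model simply ignores it.
        return self.head(self.embed(tokens))


model = TinyLM()
inps = torch.randint(0, 32, (1, 8))  # (batch=1, seq_len=8) token ids

# Same loop as the "no dynamic shape" branch of _model_call: feed one token at a
# time together with its explicit position, then stitch the per-position logits
# back into a single (batch, seq_len, vocab_size) tensor.
result_logits = []
for pos in range(inps.shape[-1]):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    logits = model(inps[:, pos : pos + 1], pos_tensor)
    result_logits.append(logits)

full_logits = torch.cat(result_logits, dim=1)
print(full_logits.shape)  # torch.Size([1, 8, 32])
```

Concatenating the per-position logits along dim=1 recovers the full-sequence logits that a harness-style `_model_call` is expected to return, at the cost of one model invocation per token.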