 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -66,6 +67,11 @@ def __init__(
         use_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
+        calibration_tasks: Optional[List[str]] = None,
+        calibration_limit: Optional[int] = None,
+        calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
+        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -87,6 +93,11 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
+        self.calibration_tasks = calibration_tasks
+        self.calibration_limit = calibration_limit
+        self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
+        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -167,6 +178,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             )
         return self
 
+    def pt2e_calibrate(
+        self,
+        prepared_module,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        calibration_data,
+        tokenizer_path,
+    ):
+        logging.info("Run calibration...")
+        try:
+            from executorch.examples.models.llama2.eval_llama_lib import (
+                GraphModuleEvalWrapper,
+            )
+            from executorch.examples.models.llama2.evaluate import evaluate_model
+        except ImportError:
+            raise ImportError(
+                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
+            )
+
+        tokenizer = get_tokenizer(tokenizer_path)
+
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
+
+        eval_wrapper = GraphModuleEvalWrapper(
+            model=prepared_module,
+            tokenizer=tokenizer,
+            max_seq_length=calibration_seq_length,
+            use_kv_cache=self.use_kv_cache,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+        eval_results = evaluate_model(
+            eval_wrapper,
+            calibration_tasks,
+            calibration_limit,
+        )
+
+        for task, res in eval_results["results"].items():
+            print(f"{task}: {res}")
+        logging.info("Calibration finish...")
+
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -189,8 +263,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                     self.pre_autograd_graph_module is not None
                 ), "Please run capture_pre_autograd_graph first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+                logging.info(
+                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                )
                 # Calibrate
-                m(*self.example_inputs)
+                if (
+                    self.calibration_tasks is not None
+                    and self.calibration_limit is not None
+                    and self.calibration_seq_length is not None
+                    and self.calibration_data is not None
+                    and self.tokenizer_path is not None
+                ):
+                    logging.info(
+                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                    )
+                    self.pt2e_calibrate(
+                        prepared_module=m,
+                        calibration_tasks=self.calibration_tasks,
+                        calibration_limit=self.calibration_limit,
+                        calibration_seq_length=self.calibration_seq_length,
+                        calibration_data=self.calibration_data,
+                        tokenizer_path=self.tokenizer_path,
+                    )
+                else:
+                    logging.info(
+                        "No calibration provided, using dummy input to calibrate..."
+                    )
+                    m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
                 self.pre_autograd_graph_module = m
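
For context, here is a minimal, self-contained sketch of the PT2E flow that pt2e_quantize() follows above (prepare_pt2e, run calibration inputs, convert_pt2e). The TinyLM module, the random token data, and the XNNPACKQuantizer choice are assumptions made purely for illustration; they are not part of this patch.

# Standalone sketch of the prepare_pt2e -> calibrate -> convert_pt2e flow.
# TinyLM and the random "token" inputs are toy stand-ins, not executorch code.
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyLM(torch.nn.Module):
    """Toy stand-in for the captured LLM graph module."""

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(32, 16)
        self.proj = torch.nn.Linear(16, 32)

    def forward(self, tokens):
        return self.proj(self.emb(tokens))


example_inputs = (torch.randint(0, 32, (1, 8)),)
m = capture_pre_autograd_graph(TinyLM().eval(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# "Calibration" means running representative inputs through the
# observer-instrumented module; the diff does this either with a greedy decode
# over calibration_data (pt2e_calibrate) or, as a fallback, example_inputs.
with torch.no_grad():
    for _ in range(4):
        m(torch.randint(0, 32, (1, 8)))

m = convert_pt2e(m)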