@@ -27,7 +27,6 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
-from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -67,11 +66,6 @@ def __init__(
         use_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
-        calibration_tasks: Optional[List[str]] = None,
-        calibration_limit: Optional[int] = None,
-        calibration_seq_length: Optional[int] = None,
-        calibration_data: Optional[str] = None,
-        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -93,11 +87,6 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
-        self.calibration_tasks = calibration_tasks
-        self.calibration_limit = calibration_limit
-        self.calibration_seq_length = calibration_seq_length
-        self.calibration_data = calibration_data
-        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -178,69 +167,6 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
         )
         return self
 
-    def pt2e_calibrate(
-        self,
-        prepared_module,
-        calibration_tasks,
-        calibration_limit,
-        calibration_seq_length,
-        calibration_data,
-        tokenizer_path,
-    ):
-        logging.info("Run calibration...")
-        try:
-            from executorch.examples.models.llama2.eval_llama_lib import (
-                GraphModuleEvalWrapper,
-            )
-            from executorch.examples.models.llama2.evaluate import evaluate_model
-        except ImportError:
-            raise ImportError(
-                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
-            )
-
-        tokenizer = get_tokenizer(tokenizer_path)
-
-        def calibrate_template(
-            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
-        ):
-            # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
-            token_list = tokenizer.encode(prompts, bos=True, eos=False)
-
-            with torch.no_grad():
-                while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        torch.tensor((pos,)),
-                    )
-                    pos += 1
-                    if pos >= len(token_list):
-                        token_list.append(torch.argmax(logits[:], dim=-1).item())
-
-        calibrate_template(
-            module=prepared_module,
-            tokenizer=tokenizer,
-            prompts=calibration_data,
-            max_len=calibration_seq_length,
-        )
-
-        eval_wrapper = GraphModuleEvalWrapper(
-            model=prepared_module,
-            tokenizer=tokenizer,
-            max_seq_length=calibration_seq_length,
-            use_kv_cache=self.use_kv_cache,
-            enable_dynamic_shape=self.enable_dynamic_shape,
-        )
-        eval_results = evaluate_model(
-            eval_wrapper,
-            calibration_tasks,
-            calibration_limit,
-        )
-
-        for task, res in eval_results["results"].items():
-            print(f"{task}: {res}")
-        logging.info("Calibration finish...")
-
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -263,33 +189,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             self.pre_autograd_graph_module is not None
         ), "Please run capture_pre_autograd_graph first"
         m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
-        logging.info(
-            f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-        )
         # Calibrate
-        if (
-            self.calibration_tasks is not None
-            and self.calibration_limit is not None
-            and self.calibration_seq_length is not None
-            and self.calibration_data is not None
-            and self.tokenizer_path is not None
-        ):
-            logging.info(
-                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-            )
-            self.pt2e_calibrate(
-                prepared_module=m,
-                calibration_tasks=self.calibration_tasks,
-                calibration_limit=self.calibration_limit,
-                calibration_seq_length=self.calibration_seq_length,
-                calibration_data=self.calibration_data,
-                tokenizer_path=self.tokenizer_path,
-            )
-        else:
-            logging.info(
-                "No calibration provided, using dummy input to calibrate..."
-            )
-            m(*self.example_inputs)
+        m(*self.example_inputs)
         m = convert_pt2e(m)
         DuplicateDynamicQuantChainPass()(m)
         self.pre_autograd_graph_module = m
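For context, here is a minimal sketch of the pt2e flow that `pt2e_quantize` performs after this change: capture the pre-autograd graph, prepare it with a quantizer, run the example inputs once as calibration, then convert. `TinyModel`, the XNNPACK quantizer choice, and the input shapes are illustrative assumptions, not code from this repository; only `capture_pre_autograd_graph`, `prepare_pt2e`, and `convert_pt2e` mirror the imports in this file, and this assumes a PyTorch build that still ships those APIs.

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyModel(torch.nn.Module):
    """Stand-in for the exported LLM; purely illustrative."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)

# Capture the pre-autograd graph, as capture_pre_autograd_graph() does upstream.
m = capture_pre_autograd_graph(TinyModel().eval(), example_inputs)

# Insert observers for the chosen quantizer.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Calibrate: with the dedicated calibration path removed, the manager simply
# runs the example inputs through the prepared module once.
m(*example_inputs)

# Replace observers with quantize/dequantize ops.
m = convert_pt2e(m)
```

The tradeoff this diff makes: without the tokenizer-driven `pt2e_calibrate` path, observers only ever see the example inputs, which keeps the export flow free of the llama2 eval dependencies at the cost of calibration coverage.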