@@ -11,7 +11,6 @@
 import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional
-from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -28,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -70,6 +70,7 @@ def __init__(
         calibration_tasks: Optional[List[str]] = None,
         calibration_limit: Optional[int] = None,
         calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
         tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
@@ -95,6 +96,7 @@ def __init__(
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
@@ -176,41 +178,51 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             )
         return self
 
-
     def pt2e_calibrate(
         self,
         prepared_module,
         calibration_tasks,
         calibration_limit,
         calibration_seq_length,
+        calibration_data,
         tokenizer_path,
     ):
         logging.info("Run calibration...")
         try:
-            from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
+            from executorch.examples.models.llama2.evaluate import (
+                EagerEvalWrapper,
+                evaluate_model,
+            )
         except ImportError:
             raise ImportError(
                 "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
             )
 
         tokenizer = get_tokenizer(tokenizer_path)
 
-        def calibrate_template(module: torch.fx.GraphModule, tokenizer, string: str = "Once upon a time", max_len: int = 128):
-            # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
-            token_list = [tokenizer.bos_id] + tokenizer.encode(string, bos=True, eos=False)
-
-            with torch.no_grad():
-                while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        torch.tensor((pos, )),
-                    )
-                    pos += 1
-                    if pos >= len(token_list):
-                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
 
-        calibrate_template(prepared_module, tokenizer, string="Once upon a time", max_len=calibration_seq_length)
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
 
         eval_wrapper = EagerEvalWrapper(
             model=prepared_module.to(device="cuda"),
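# Illustrative, self-contained sketch of the calibration pattern that the new
# calibrate_template above implements (a simplified stand-in, not the repository's
# code): walk the prompt one token at a time so the prepared module sees
# activations at every position, then keep greedy-decoding until EOS or max_len.
# TinyTokenizer and TinyLM are hypothetical stubs for the real tokenizer and the
# prepared torch.fx.GraphModule.
import torch


class TinyTokenizer:
    bos_id, eos_id = 1, 2

    def encode(self, text: str, bos: bool, eos: bool):
        ids = [ord(c) % 29 + 3 for c in text]  # toy ids in [3, 31]
        return ([self.bos_id] if bos else []) + ids + ([self.eos_id] if eos else [])


class TinyLM(torch.nn.Module):
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, 16)
        self.head = torch.nn.Linear(16, vocab_size)

    def forward(self, token, pos):
        # token: (1, 1) current token id; pos: (1,) current position (ignored here)
        return self.head(self.emb(token))  # (1, 1, vocab_size) logits


def calibrate(module, tokenizer, prompt: str, max_len: int):
    pos = torch.tensor(0, dtype=torch.int64)
    token_list = tokenizer.encode(prompt, bos=True, eos=False)
    with torch.no_grad():
        while token_list[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(torch.full((1, 1), token_list[pos]), torch.tensor((pos,)))
            pos += 1
            if pos >= len(token_list):  # past the prompt: append the greedy token
                token_list.append(torch.argmax(logits, dim=-1)[0, -1].item())
    return token_list


print(calibrate(TinyLM(), TinyTokenizer(), "Once upon a time", max_len=32))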
@@ -251,20 +263,26 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                     self.pre_autograd_graph_module is not None
                 ), "Please run capture_pre_autograd_graph first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+                logging.info(
+                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                )
                 # Calibrate
-                logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}, tokenizer_path: {self.tokenizer_path}")
                 if (
                     self.calibration_tasks is not None
                     and self.calibration_limit is not None
                     and self.calibration_seq_length is not None
+                    and self.calibration_data is not None
                     and self.tokenizer_path is not None
                 ):
-                    logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}")
+                    logging.info(
+                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                    )
                     self.pt2e_calibrate(
                         prepared_module=m,
                         calibration_tasks=self.calibration_tasks,
                         calibration_limit=self.calibration_limit,
                         calibration_seq_length=self.calibration_seq_length,
+                        calibration_data=self.calibration_data,
                         tokenizer_path=self.tokenizer_path,
                     )
                 else:
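# Minimal usage sketch (assumed, not an example shipped with this change) showing
# how the new calibration_data field gates PT2E calibration: pt2e_quantize only
# calls pt2e_calibrate when calibration_tasks, calibration_limit,
# calibration_seq_length, calibration_data and tokenizer_path are all set.
def quantize_with_calibration(manager, quantizers):
    # `manager` is an already-constructed LLMEdgeManager; the values below are
    # illustrative placeholders (task name, prompt, sequence length, tokenizer path).
    manager.calibration_tasks = ["wikitext"]
    manager.calibration_limit = 1
    manager.calibration_seq_length = 128
    manager.calibration_data = "Once upon a time"  # prompt fed to calibrate_template
    manager.tokenizer_path = "/path/to/tokenizer.model"
    return manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)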