3 files changed: +22 -12 lines changed
@@ -357,9 +357,9 @@ def _initialize_model(
357
357
_set_gguf_kwargs (builder_args , is_et = is_pte , context = "generate" )
358
358
359
359
if builder_args .dso_path :
360
- assert (
361
- quantize is None or quantize == "{ }"
362
- ), "quantize not valid for exported DSO model. Specify quantization during export."
360
+ # assert (
361
+ # quantize is None or quantize == "{ }"
362
+ # ), "quantize not valid for exported DSO model. Specify quantization during export."
363
363
364
364
t0 = time .time ()
365
365
model = _load_model (builder_args , only_config = True )
@@ -379,9 +379,9 @@ def _initialize_model(
379
379
except :
380
380
raise RuntimeError (f"Failed to load AOTI compiled { builder_args .dso_path } " )
381
381
elif builder_args .pte_path :
382
- assert (
383
- quantize is None or quantize == "{ }"
384
- ), "quantize not valid for exported PTE model. Specify quantization during export."
382
+ # assert (
383
+ # quantize is None or quantize == "{ }"
384
+ # ), "quantize not valid for exported PTE model. Specify quantization during export."
385
385
386
386
t0 = time .time ()
387
387
model = _load_model (builder_args , only_config = True )
@@ -295,10 +295,16 @@ def _add_arguments_common(parser):
295
295
296
296
297
297
def arg_init (args ):
298
- if hasattr (args , ' quantize' ) and Path (args .quantize ).is_file ():
298
+ if hasattr (args , " quantize" ) and Path (args .quantize ).is_file ():
299
299
with open (args .quantize , "r" ) as f :
300
300
args .quantize = json .loads (f .read ())
301
301
302
- if hasattr (args , 'seed' ) and args .seed :
302
+ if isinstance (args .quantize , str ):
303
+ args .quantize = json .loads (args .quantize )
304
+
305
+ # if we specify dtype in quantization recipe, replicate it as args.dtype
306
+ args .dtype = args .quantize .get ("precision" , {}).get ("dtype" , args .dtype )
307
+
308
+ if hasattr (args , "seed" ) and args .seed :
303
309
torch .manual_seed (args .seed )
304
310
return args
@@ -158,7 +158,7 @@ def _model_call(self, inps):
158
158
x = seq .index_select (0 , input_pos ).view (1 , - 1 )
159
159
start = time .time ()
160
160
logits = model_forward (self ._model , x , input_pos )
161
- self .times .append (time .time ()- start )
161
+ self .times .append (time .time () - start )
162
162
return logits
163
163
164
164
def _model_generate (self , context , max_length , eos_token_id ):
@@ -266,9 +266,13 @@ def main(args) -> None:
266
266
device = builder_args .device ,
267
267
)
268
268
print (f"Time to run eval: { time .time () - t1 :.02f} s." )
269
- times = torch .tensor (result ["times" ])
270
- print (f"Time in model.forward: { times .sum ():.02f} s, over { times .numel ()} model evaluations" )
271
- print (f"forward run time stats - Median: { times .median ():.02f} s Min: { times .min ():.02f} s Max: { times .max ():.02f} s" )
269
+ times = torch .tensor (result ["times" ])
270
+ print (
271
+ f"Time in model.forward: { times .sum ():.02f} s, over { times .numel ()} model evaluations"
272
+ )
273
+ print (
274
+ f"forward run time stats - Median: { times .median ():.02f} s Min: { times .min ():.02f} s Max: { times .max ():.02f} s"
275
+ )
272
276
if builder_args .dso_path :
273
277
print (f"For model { builder_args .dso_path } " )
274
278
elif builder_args .pte_path :
You can’t perform that action at this time.
0 commit comments