@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
+        help="Provide the dtype of the model. This must match the dtypes supported by the backends you are using. "
+        "Please be aware that only some backends support fp16 and bf16.",
     )

     parser.add_argument(
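For context (not part of the diff): a minimal, self-contained argparse sketch of how the changed flag reads after this hunk. The flag name --dtype-override is inferred from args.dtype_override used later in the diff; the keyword arguments mirror the hunk above.

# Illustrative sketch only: the flag name is inferred, the rest mirrors the hunk.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype-override",
    default="fp32",
    type=str,
    choices=["fp32", "fp16", "bf16"],
    help="Provide the dtype of the model. This must match the dtypes supported by the backends you are using. "
    "Please be aware that only some backends support fp16 and bf16.",
)
args = parser.parse_args(["--dtype-override", "bf16"])
print(args.dtype_override)  # -> "bf16" (argparse maps --dtype-override to args.dtype_override)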
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )

-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
     )

+    return edge_manager
+

 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
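To make the new conversion step concrete: DType[args.dtype_override] is an enum lookup by member name, and .to_torch_dtype() maps the result to a torch dtype before the model is cast. Below is a minimal sketch of that behavior under stated assumptions; DTypeSketch is a hypothetical stand-in, not the repo's actual DType class.

# Minimal sketch, assuming DType behaves like a name-keyed enum with a
# to_torch_dtype() helper, as implied by DType[args.dtype_override] and
# dtype_override.to_torch_dtype() in the hunk above. DTypeSketch is hypothetical.
from enum import Enum

import torch


class DTypeSketch(Enum):
    fp32 = torch.float32
    fp16 = torch.float16
    bf16 = torch.bfloat16

    def to_torch_dtype(self) -> torch.dtype:
        return self.value


# Name-based lookup mirrors DType[args.dtype_override]:
dtype_override = DTypeSketch["bf16"]
assert dtype_override.to_torch_dtype() is torch.bfloat16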
@@ -1006,6 +1005,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")

+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1023,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")

     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
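A practical note on this last change (hedged, not stated in the diff): the exporter no longer infers the dtype from the checkpoint, so a checkpoint saved in fp16 or bf16 should be exported with a matching --dtype-override. The sketch below mirrors the removed state_dict-based inference to inspect a checkpoint before choosing the flag; the path is a placeholder and a flat {name: tensor} state dict is assumed.

# Sketch only: inspect a checkpoint's parameter dtype before picking --dtype-override.
# Mirrors the removed inference logic above. The path is a placeholder; adjust the
# lookup if the checkpoint nests its weights under another key.
import torch

state_dict = torch.load("path/to/checkpoint.pth", map_location="cpu")
first_param = state_dict[next(iter(state_dict))]
print(f"checkpoint dtype: {first_param.dtype}")  # e.g. torch.bfloat16 -> --dtype-override bf16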