@@ -806,20 +806,22 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        modelname,
-        model_class_name,
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+    model, example_inputs, example_kwarg_inputs, _ = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
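
Note: the hunk above is formatting-only. Wrapping the right-hand side of the assignment in parentheses lets the EagerModelFactory.create_model(...) call sit one indent level deeper (the layout auto-formatters such as black produce for long assignments) without changing the four-tuple that is unpacked. A minimal self-contained sketch of that equivalence, using a stand-in function rather than the real factory (the names "llama2" and "Llama2Model" here are illustrative values, not taken from this diff):

# Sketch: parenthesizing the right-hand side of an assignment changes
# only layout, not semantics. `fake_create_model` stands in for
# EagerModelFactory.create_model, which returns a 4-tuple.
def fake_create_model(name, cls_name, **kwargs):
    return object(), (1,), {}, None

# Old layout (pre-change): call starts on the assignment line.
model, example_inputs, example_kwarg_inputs, _ = fake_create_model(
    "llama2",
    "Llama2Model",
    use_kv_cache=True,
)

# New layout (post-change): extra parentheses, one more indent level.
model2, example_inputs2, example_kwarg_inputs2, _ = (
    fake_create_model(
        "llama2",
        "Llama2Model",
        use_kv_cache=True,
    )
)

# Both forms unpack the same values.
assert example_inputs == example_inputs2 == (1,)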