@@ -800,26 +800,26 @@ def _load_llama_model(
         modelname = "llama2"
         model_class_name = "Llama2Model"
     elif modelname in TORCHTUNE_DEFINED_MODELS:
-        raise NotImplementedError("Torchtune Llama models are not yet supported in ExecuTorch export.")
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, _ = (
-        EagerModelFactory.create_model(
-            modelname,
-            model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            use_kv_cache=use_kv_cache,
-            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-            generate_full_logits=generate_full_logits,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            max_seq_len=max_seq_len,
-            enable_dynamic_shape=enable_dynamic_shape,
-            output_prune_map_path=output_prune_map_path,
-            args=args,
-        )
+    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
+        modelname,
+        model_class_name,
+        checkpoint=checkpoint,
+        checkpoint_dir=checkpoint_dir,
+        params=params_path,
+        use_kv_cache=use_kv_cache,
+        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+        generate_full_logits=generate_full_logits,
+        fairseq2=weight_type == WeightType.FAIRSEQ2,
+        max_seq_len=max_seq_len,
+        enable_dynamic_shape=enable_dynamic_shape,
+        output_prune_map_path=output_prune_map_path,
+        args=args,
     )
     if dtype_override:
         assert isinstance(