
Commit c9bbe12

Create new method for example kwarg inputs instead
1 parent ec80bba commit c9bbe12

File tree

1 file changed (+16, -14 lines)


examples/models/llama2/export_llama_lib.py

Lines changed: 16 additions & 14 deletions
@@ -806,20 +806,22 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        modelname,
-        model_class_name,
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+    model, example_inputs, example_kwarg_inputs, _ = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
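
For context on the unpacked tuple: `example_inputs` carries positional example arguments and `example_kwarg_inputs` carries keyword example arguments, and the two are passed separately to the exporter. The change itself only wraps the long assignment in parentheses so the call can break across lines under standard formatters; the call is unchanged. Below is a minimal, self-contained sketch of the factory pattern, assuming the torch.export API; `TinyModel` and the standalone `create_model` are hypothetical stand-ins for illustration, not ExecuTorch's actual `EagerModelFactory`.

# Minimal sketch: a factory that returns example kwarg inputs alongside
# positional example inputs, so both can be forwarded to the exporter.
# NOTE: TinyModel and create_model are hypothetical stand-ins, not the
# real ExecuTorch EagerModelFactory API.
from typing import Any, Dict, Optional, Tuple

import torch


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # input_pos participates in the computation so the exported graph keeps it.
        return self.linear(tokens) + input_pos.to(tokens.dtype)


def create_model() -> Tuple[torch.nn.Module, Tuple[Any, ...], Dict[str, Any], Optional[Any]]:
    model = TinyModel().eval()
    example_inputs = (torch.randn(1, 8),)                    # positional args
    example_kwarg_inputs = {"input_pos": torch.tensor([0])}  # keyword args
    return model, example_inputs, example_kwarg_inputs, None


model, example_inputs, example_kwarg_inputs, _ = create_model()
# Positional and keyword example inputs go to the exporter as separate arguments.
exported = torch.export.export(model, example_inputs, example_kwarg_inputs)
print(exported.graph_signature)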
