Skip to content

Commit 261d36e

Browse files
committed
Accept model type parameter in export_llama
1 parent 5134c42 commit 261d36e

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878

7979

8080
EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
81-
TORCHTUNE_DEFINED_MODELS = []
81+
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
8282

8383

8484
class WeightType(Enum):
@@ -798,11 +798,12 @@ def _load_llama_model(
798798
modelname = "llama2"
799799
model_class_name = "Llama2Model"
800800
elif modelname in TORCHTUNE_DEFINED_MODELS:
801-
raise NotImplementedError("Torchtune Llama models are not yet supported in ExecuTorch export.")
801+
if modelname == "llama3_2_vision":
802+
model_class_name = "Llama3_2Decoder"
802803
else:
803804
raise ValueError(f"{modelname} is not a valid Llama model.")
804805

805-
model, example_inputs, example_kwarg_inputs, _ = (
806+
model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
806807
EagerModelFactory.create_model(
807808
modelname,
808809
model_class_name,

0 commit comments

Comments
 (0)