Skip to content

Commit 99d5bfb

Browse files
committed
Accept model type parameter in export_llama
1 parent c9bbe12 commit 99d5bfb

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979

8080

8181
EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
82-
TORCHTUNE_DEFINED_MODELS = []
82+
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
8383

8484

8585
class WeightType(Enum):
@@ -800,13 +800,12 @@ def _load_llama_model(
800800
modelname = "llama2"
801801
model_class_name = "Llama2Model"
802802
elif modelname in TORCHTUNE_DEFINED_MODELS:
803-
raise NotImplementedError(
804-
"Torchtune Llama models are not yet supported in ExecuTorch export."
805-
)
803+
if modelname == "llama3_2_vision":
804+
model_class_name = "Llama3_2Decoder"
806805
else:
807806
raise ValueError(f"{modelname} is not a valid Llama model.")
808807

809-
model, example_inputs, example_kwarg_inputs, _ = (
808+
model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
810809
EagerModelFactory.create_model(
811810
modelname,
812811
model_class_name,

0 commit comments

Comments
 (0)