Skip to content

Commit 03779eb

Browse files
committed
Fix vision model example input
1 parent f9c001a commit 03779eb

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

examples/models/llama3_2_vision/model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
import json
10-
from typing import Any, Dict, Tuple
10+
from typing import Any, Dict
1111

1212
import torch
1313

@@ -122,17 +122,20 @@ def get_eager_model(self) -> torch.nn.Module:
122122
else:
123123
return self.model_.to(torch.float16)
124124

125-
def get_example_inputs(self) -> Tuple[Tuple, Dict]:
125+
def get_example_inputs(self):
126126
return (
127-
(torch.ones(1, 64, dtype=torch.long),), # positional inputs
128-
{
129-
# "mask": None,
130-
# "encoder_input": None,
131-
# "encoder_mask": None,
132-
# "input_pos": torch.ones(64, dtype=torch.long),
133-
} # kwarg inputs
127+
torch.ones(1, 64, dtype=torch.long), # positional inputs
134128
)
135129

130+
def get_example_kwarg_inputs(self):
131+
# TODO: add input_pos and mask when after making cache work.
132+
return {
133+
# "mask": None,
134+
# "encoder_input": None,
135+
# "encoder_mask": None,
136+
# "input_pos": torch.ones(64, dtype=torch.long),
137+
}
138+
136139
def get_dynamic_shapes(self):
137140
dim = torch.export.Dim("token_dim", min=1,max=self.max_seq_len)
138141
dynamic_shapes = {

0 commit comments

Comments
 (0)