Skip to content

Commit c6909a7

Browse files
committed
Address comments
1 parent f5a747c commit c6909a7

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
params = {
2222
"dim": 2048,
2323
"ffn_dim_multiplier": 1.3,
24-
"fusion_interval": 4,
24+
"fusion_interval": 2,
2525
"intermediate_dim": 14336,
2626
"multiple_of": 1024,
2727
"n_heads": 32,
2828
"n_kv_heads": 8,
29-
"n_layers": 4,
29+
"n_layers": 2,
3030
"n_special_tokens": 8,
3131
"norm_eps": 1e-05,
3232
"rope_theta": 500000.0,
@@ -48,11 +48,10 @@ def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
4848
for child in model.children():
4949
self._set_requires_grad_false(child)
5050

51-
def test_llama3_2_text_decoder(self) -> None:
51+
def test_llama3_2_text_decoder_aoti(self) -> None:
5252
with tempfile.NamedTemporaryFile(mode="w") as param_file:
5353
json.dump(params, param_file, indent=2)
5454
param_file.flush()
55-
print(param_file.name)
5655
model = Llama3_2Decoder(
5756
encoder_max_seq_len=6404,
5857
generate_full_logits=True,
@@ -79,7 +78,6 @@ def test_llama3_2_text_decoder(self) -> None:
7978
kwargs=model.get_example_kwarg_inputs(),
8079
package_path=os.path.join(tmpdir, "text_decoder.pt2"),
8180
)
82-
print(path)
8381
encoder_aoti = torch._inductor.aoti_load_package(path)
8482

8583
y = encoder_aoti(

0 commit comments

Comments
 (0)