
Commit 73d6e25

[llama-mm] Make text decoder exportable

Summary: Adds a source transformation and changes the example inputs to make the text decoder exportable.

Test Plan: Added a new unit test.

1 parent e721945 · commit 73d6e25
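
Taken together, these changes let the text decoder round-trip through torch.export. A minimal sketch of that flow, mirroring the new unit test below (the constructor kwargs are copied from it; "params.json" is a placeholder path to a params file):

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)

# Sketch only: kwargs mirror the unit test; "params.json" is a placeholder.
model = Llama3_2Decoder(
    encoder_max_seq_len=6404,
    generate_full_logits=True,
    enable_dynamic_shape=True,
    use_kv_cache=True,
    params="params.json",
    dtype=torch.float32,
)
ep = torch.export.export(
    model.get_eager_model().eval(),  # MHA already replaced with the inference variant
    model.get_example_inputs(),
    kwargs=model.get_example_kwarg_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)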

File tree

4 files changed: +115 −13 lines changed


examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 24 additions & 13 deletions
@@ -17,6 +17,7 @@
 )
 
 from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
 
@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
-        self.dtype = None
+        self.dtype = kwargs.get("dtype", torch.float16)
         self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
                 dtype=torch.bool,
             )
         )
-        self.input_pos = torch.arange(self.max_seq_len)
+        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
 
         # Load checkpoint and params.
         device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
             rope_base=params["rope_theta"],
             intermediate_dim=params["intermediate_dim"],
         )
+
+        # Source transformation for MultiHeadAttention
+        self.model_ = replace_mha_with_inference_mha(self.model_)
         # Save params for future use.
         for param_name, param_val in params.items():
             setattr(self.model_, param_name, param_val)
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
         self.model_.setup_caches(
             batch_size=1,
             dtype=self.dtype,
+            encoder_max_seq_len=self.encoder_max_seq_len,
             decoder_max_seq_len=self.max_seq_len,
         )
+        # number of tokens for example input
+        self.n_tokens = 34
+        self.model_.to(self.dtype)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            return self.model_.to(self.dtype)
-        else:
-            return self.model_.to(torch.float16)
+        return self.model_
 
     def get_example_inputs(self):
-        return (torch.ones(1, 32, dtype=torch.long),)
+        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)
 
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, :32],
-                "mask": self.causal_mask[None, :32],
-                # "encoder_input": None,
-                # "encoder_mask": None,
+                "input_pos": self.input_pos[None, : self.n_tokens],
+                "mask": self.causal_mask[None, : self.n_tokens],
+                "encoder_input": torch.randn(
+                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
+                ),
+                "encoder_mask": torch.ones(
+                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
+                ),
             }
         else:
             return None
 
     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             dynamic_shapes = {
                 "tokens": {0: batch_size, 1: dim_seq_len},
-                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "encoder_input": None,
+                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                 "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                 "input_pos": {0: batch_size, 1: dim_seq_len},
             }
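
For export to succeed, the example kwargs and get_dynamic_shapes() must agree: encoder_input is provided at its maximum size and marked fully static (None), while encoder_mask, mask, and input_pos share the dynamic token dimension with tokens. A hedged sketch of the shapes involved, using values from this diff and the unit test below (in the real wrapper, dim and encoder_max_seq_len come from the params file and the caller):

import torch

n_tokens = 34               # example prefill length hardcoded above
encoder_max_seq_len = 6404  # value the unit test passes to the constructor
dim = 4096                  # model width from the test's params

tokens = torch.ones(1, n_tokens, dtype=torch.int64)             # dim 1 dynamic
input_pos = torch.arange(n_tokens, dtype=torch.int64)[None, :]  # dim 1 dynamic
encoder_input = torch.randn(1, encoder_max_seq_len, dim)        # static, so spec is None
encoder_mask = torch.ones(1, n_tokens, encoder_max_seq_len, dtype=torch.bool)  # dim 1 dynamic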

examples/models/llama3_2_vision/text_decoder/test/__init__.py

Whitespace-only changes.

examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
+# Only test AOTI in this file
+import json
+import os
+import tempfile
+import unittest
+
+import torch
+from torch.testing import assert_close
+
+from ..model import Llama3_2Decoder
+
+params = {
+    "dim": 4096,
+    "ffn_dim_multiplier": 1.3,
+    "fusion_interval": 4,
+    "intermediate_dim": 14336,
+    "multiple_of": 1024,
+    "n_heads": 32,
+    "n_kv_heads": 8,
+    "n_layers": 4,
+    "n_special_tokens": 8,
+    "norm_eps": 1e-05,
+    "rope_theta": 500000.0,
+    "use_scaled_rope": True,
+    "vision_chunk_size": 560,
+    "vision_max_num_chunks": 4,
+    "vocab_size": 128256,
+    "vision_num_cross_attention_layers": 1,
+}
+
+
+class TextDecoderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
+        for param in model.parameters():
+            param.requires_grad = False
+        for child in model.children():
+            self._set_requires_grad_false(child)
+
+    def test_llama3_2_text_decoder(self) -> None:
+        with tempfile.NamedTemporaryFile(mode="w") as param_file:
+            json.dump(params, param_file, indent=2)
+            param_file.flush()
+            print(param_file.name)
+            model = Llama3_2Decoder(
+                encoder_max_seq_len=6404,
+                generate_full_logits=True,
+                enable_dynamic_shape=True,
+                use_kv_cache=True,
+                params=param_file.name,
+                dtype=torch.float32,
+            )
+            encoder = model.get_eager_model().eval()
+            self._set_requires_grad_false(encoder)
+
+            # AOTI
+            with torch.no_grad(), torch.inference_mode():
+                ep = torch.export.export(
+                    encoder,
+                    model.get_example_inputs(),
+                    kwargs=model.get_example_kwarg_inputs(),
+                    dynamic_shapes=model.get_dynamic_shapes(),
+                )
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    path = torch._inductor.aoti_compile_and_package(
+                        ep,
+                        model.get_example_inputs(),
+                        kwargs=model.get_example_kwarg_inputs(),
+                        package_path=os.path.join(tmpdir, "text_decoder.pt2"),
+                    )
+                    print(path)
+                    encoder_aoti = torch._inductor.aoti_load_package(path)
+
+                    y = encoder_aoti(
+                        *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+                    )
+
+                    eager_res = encoder.forward(
+                        *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+                    )
+                    assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
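
Because aoti_compile_and_package writes a self-contained .pt2 archive, the compiled decoder can also be reloaded outside the test. A minimal sketch, assuming a package saved to an illustrative path:

import torch

# Load a previously compiled AOTI package and call it like a regular module.
decoder_runner = torch._inductor.aoti_load_package("text_decoder.pt2")  # illustrative path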

pytest.ini

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ addopts =
     examples/models/llama/tests
     examples/models/llama3_2_vision/preprocess
     examples/models/llama3_2_vision/vision_encoder/test
+    examples/models/llama3_2_vision/text_decoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
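
With the test directory registered, the new test runs under the repo's standard pytest discovery; an illustrative invocation from the repo root:

python -m pytest examples/models/llama3_2_vision/text_decoder/test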
