
Commit 8d71cd3

[llama-mm] Make text decoder exportable (#6999)
* [llama-mm] Make text decoder exportable

  Summary: Adds a source transformation and changes the example inputs to make the text decoder exportable.

  Test Plan: Added a new unit test.

* Make the test model smaller
* Ignore e2e test
* Do not run e2e test
* Address comments
1 parent 2adb1bc commit 8d71cd3
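In short, this change makes the following flow work end to end: construct the decoder (which now applies the attention source transform internally), then export it with the example inputs and dynamic shapes defined in model.py below. A minimal sketch mirroring the new unit test; the params path is illustrative (the test writes the params dict to a temporary file):

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)

# The constructor now runs replace_mha_with_inference_mha, swapping
# torchtune's MultiHeadAttention for an export-friendly implementation.
model = Llama3_2Decoder(
    encoder_max_seq_len=6404,
    generate_full_logits=True,
    enable_dynamic_shape=True,
    use_kv_cache=True,
    params="params.json",  # illustrative path to a params dict
    dtype=torch.float32,
)

with torch.no_grad():
    ep = torch.export.export(
        model.get_eager_model().eval(),
        model.get_example_inputs(),
        kwargs=model.get_example_kwarg_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )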

File tree

4 files changed: +117 -15 lines

  examples/models/llama3_2_vision/text_decoder/model.py
  examples/models/llama3_2_vision/text_decoder/test/__init__.py
  examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py
  pytest.ini

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 24 additions & 13 deletions
@@ -17,6 +17,7 @@
 )

 from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
-        self.dtype = None
+        self.dtype = kwargs.get("dtype", torch.float16)
         self.use_checkpoint = False

         ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
                 dtype=torch.bool,
             )
         )
-        self.input_pos = torch.arange(self.max_seq_len)
+        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

         # Load checkpoint and params.
         device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
             rope_base=params["rope_theta"],
             intermediate_dim=params["intermediate_dim"],
         )
+
+        # Source transformation for MultiHeadAttention
+        self.model_ = replace_mha_with_inference_mha(self.model_)
         # Save params for future use.
         for param_name, param_val in params.items():
             setattr(self.model_, param_name, param_val)
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
         self.model_.setup_caches(
             batch_size=1,
             dtype=self.dtype,
+            encoder_max_seq_len=self.encoder_max_seq_len,
             decoder_max_seq_len=self.max_seq_len,
         )
+        # number of tokens for example input
+        self.n_tokens = 34
+        self.model_.to(self.dtype)

     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            return self.model_.to(self.dtype)
-        else:
-            return self.model_.to(torch.float16)
+        return self.model_

     def get_example_inputs(self):
-        return (torch.ones(1, 32, dtype=torch.long),)
+        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, :32],
-                "mask": self.causal_mask[None, :32],
-                # "encoder_input": None,
-                # "encoder_mask": None,
+                "input_pos": self.input_pos[None, : self.n_tokens],
+                "mask": self.causal_mask[None, : self.n_tokens],
+                "encoder_input": torch.randn(
+                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
+                ),
+                "encoder_mask": torch.ones(
+                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
+                ),
             }
         else:
             return None

     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             dynamic_shapes = {
                 "tokens": {0: batch_size, 1: dim_seq_len},
-                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "encoder_input": None,
+                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                 "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                 "input_pos": {0: batch_size, 1: dim_seq_len},
             }
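For readers unfamiliar with the dynamic-shape spec returned above: a torch.export.Dim marks a tensor axis as symbolic within a range, while an integer or None pins an axis to the size seen in the example inputs. A minimal, self-contained sketch of the same convention, using a toy module (TinyDecoder is illustrative, not part of this repo):

import torch


class TinyDecoder(torch.nn.Module):
    # Toy stand-in: tokens is [batch, seq_len]; mask is [batch, seq_len, max_seq_len].
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return tokens.unsqueeze(-1) * mask.to(torch.int64)


max_seq_len = 64
dim_seq_len = torch.export.Dim("token_dim", min=1, max=max_seq_len)

ep = torch.export.export(
    TinyDecoder(),
    (
        torch.ones(1, 34, dtype=torch.int64),
        torch.ones(1, 34, max_seq_len, dtype=torch.bool),
    ),
    dynamic_shapes={
        "tokens": {0: 1, 1: dim_seq_len},  # axis 1 may vary within [1, 64]
        "mask": {0: 1, 1: dim_seq_len, 2: None},  # axis 2 stays static
    },
)

# The exported program now accepts any sequence length in the declared range.
out = ep.module()(
    torch.ones(1, 5, dtype=torch.int64),
    torch.ones(1, 5, max_seq_len, dtype=torch.bool),
)
print(out.shape)  # torch.Size([1, 5, 64])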

examples/models/llama3_2_vision/text_decoder/test/__init__.py

Whitespace-only changes.

examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Export and ExecuTorch tests for the text decoder are covered by test_models.sh.
+# Only test AOTI in this file.
+import json
+import os
+import tempfile
+import unittest
+
+import torch
+
+from executorch.examples.models.llama3_2_vision.text_decoder.model import (
+    Llama3_2Decoder,
+)
+from torch.testing import assert_close
+
+params = {
+    "dim": 2048,
+    "ffn_dim_multiplier": 1.3,
+    "fusion_interval": 2,
+    "intermediate_dim": 14336,
+    "multiple_of": 1024,
+    "n_heads": 32,
+    "n_kv_heads": 8,
+    "n_layers": 2,
+    "n_special_tokens": 8,
+    "norm_eps": 1e-05,
+    "rope_theta": 500000.0,
+    "use_scaled_rope": True,
+    "vision_chunk_size": 560,
+    "vision_max_num_chunks": 4,
+    "vocab_size": 21008,
+    "vision_num_cross_attention_layers": 1,
+}
+
+
+class TextDecoderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
+        for param in model.parameters():
+            param.requires_grad = False
+        for child in model.children():
+            self._set_requires_grad_false(child)
+
+    def test_llama3_2_text_decoder_aoti(self) -> None:
+        with tempfile.NamedTemporaryFile(mode="w") as param_file:
+            json.dump(params, param_file, indent=2)
+            param_file.flush()
+            model = Llama3_2Decoder(
+                encoder_max_seq_len=6404,
+                generate_full_logits=True,
+                enable_dynamic_shape=True,
+                use_kv_cache=True,
+                params=param_file.name,
+                dtype=torch.float32,
+            )
+        decoder = model.get_eager_model().eval()
+        self._set_requires_grad_false(decoder)
+
+        # AOTI: export, compile into a .pt2 package, reload, and run.
+        with torch.no_grad(), torch.inference_mode():
+            ep = torch.export.export(
+                decoder,
+                model.get_example_inputs(),
+                kwargs=model.get_example_kwarg_inputs(),
+                dynamic_shapes=model.get_dynamic_shapes(),
+            )
+            with tempfile.TemporaryDirectory() as tmpdir:
+                path = torch._inductor.aoti_compile_and_package(
+                    ep,
+                    model.get_example_inputs(),
+                    kwargs=model.get_example_kwarg_inputs(),
+                    package_path=os.path.join(tmpdir, "text_decoder.pt2"),
+                )
+                decoder_aoti = torch._inductor.aoti_load_package(path)
+
+                y = decoder_aoti(
+                    *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+                )
+
+        eager_res = decoder.forward(
+            *model.get_example_inputs(), **model.get_example_kwarg_inputs()
+        )
+        assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
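The .pt2 artifact produced by aoti_compile_and_package is self-contained, so it can be reloaded and run in a later process without re-exporting. A minimal sketch, assuming the package was saved under a persistent path (the paths here are illustrative):

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,
)

# Rebuild the model only to reuse its example-input helpers, so the input
# shapes match what was declared at export time.
model = Llama3_2Decoder(
    encoder_max_seq_len=6404,
    generate_full_logits=True,
    enable_dynamic_shape=True,
    use_kv_cache=True,
    params="params.json",  # illustrative path
    dtype=torch.float32,
)

decoder_aoti = torch._inductor.aoti_load_package("text_decoder.pt2")  # illustrative path
logits = decoder_aoti(*model.get_example_inputs(), **model.get_example_kwarg_inputs())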

pytest.ini

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@ addopts =
     examples/models/llama/tests
     examples/models/llama3_2_vision/preprocess
     examples/models/llama3_2_vision/vision_encoder/test
+    examples/models/llama3_2_vision/text_decoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
@@ -43,8 +44,8 @@ addopts =
     extension/pybindings/test
     # Runtime
     runtime
-    # test
-    test/end2end/test_end2end.py
+    # test TODO: fix these tests
+    # test/end2end/test_end2end.py
     --ignore=backends/xnnpack/test/ops/linear.py
     --ignore=backends/xnnpack/test/models/llama2_et_example.py
     # T200992559: Add torchao to ET as core dependency
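With the new directory listed in addopts, the decoder test is collected by a plain pytest run from the repo root; it can also be targeted directly. A hypothetical programmatic invocation:

import pytest

# Run only the new text decoder tests, verbosely (equivalent to passing the
# path on the pytest command line).
raise SystemExit(pytest.main(["examples/models/llama3_2_vision/text_decoder/test", "-v"]))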
