Commit a6cfb03

[llama-mm] Onboard torchtune vision encoder to ExecuTorch/AOTI (#6807)
Summary: This PR adds `llama3_2_vision_encoder` to `examples/models/llama3_2_vision/vision_encoder` and adds CI jobs for it.

Test Plan: Rely on the newly added CI jobs.

Tags: [ghstack-poisoned]
1 parent d8ec9ae commit a6cfb03
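With this registration in place, the encoder should also be exportable through the standard examples flow (for instance, something along the lines of `python -m examples.portable.scripts.export --model_name="llama3_2_vision_encoder"`; the exact entry point is assumed here and follows the `test_models.sh` convention mentioned in the new unit test), in addition to the AOTI path exercised by the test below.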

File tree

7 files changed (+152, −1 lines)

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
     "ic4": "linux.12xlarge",
     "resnet50": "linux.12xlarge",
     "llava": "linux.12xlarge",
+    "llama3_2_vision_encoder": "linux.12xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
     "llama2": ("llama", "Llama2Model"),
     "llama": ("llama", "Llama2Model"),
+    "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
     "lstm": ("lstm", "LSTMModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),

examples/models/llama3_2_vision/vision_encoder/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
    "FlamingoVisionEncoderModel",
    "VisionEncoderConfig",
]

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

import torch

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules._position_embeddings import (
    replace_tile_positional_embedding,
    replace_tiled_token_positional_embedding,
)
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder

max_seq_len = 8192
in_channels = 3
tile_size = 560
max_num_tiles = 4
# how many tokens per image generated by the vision encoder
tokens_per_image = 6404
# how many images to cache in the kv cache in cross attention
kv_cache_image_num = 1
# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
encoder_max_seq_len = tokens_per_image * kv_cache_image_num


@dataclass
class VisionEncoderConfig:
    patch_size: int = 14
    num_heads: int = 16
    clip_embed_dim: int = 1280
    clip_num_layers: int = 32
    clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
    decoder_embed_dim: int = 4096
    num_layers_projection: int = 8
    tile_size: int = 560
    max_num_tiles: int = 4
    in_channels: int = 3


class FlamingoVisionEncoderModel(EagerModelBase):
    def __init__(self, config: Optional[VisionEncoderConfig] = None):
        super().__init__()
        if config is None:
            config = VisionEncoderConfig()
        self.config = config
        self.model = llama3_2_vision_encoder(
            patch_size=config.patch_size,
            num_heads=config.num_heads,
            clip_embed_dim=config.clip_embed_dim,
            clip_num_layers=config.clip_num_layers,
            clip_hidden_states=config.clip_hidden_states,
            decoder_embed_dim=config.decoder_embed_dim,
            num_layers_projection=config.num_layers_projection,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
            in_channels=config.in_channels,
        )
        self.model = replace_tile_positional_embedding(self.model)
        self.model = replace_tiled_token_positional_embedding(self.model)
        self.image = torch.randn(
            1, 1, 4, 3, self.config.tile_size, self.config.tile_size
        )
        self.aspect_ratio = torch.tensor([[[1, 2]]])
        self.sample_inputs = (
            self.image,
            self.aspect_ratio,
        )

    def get_eager_model(self, **kwargs):
        return self.model

    def get_example_inputs(self):
        return self.sample_inputs

    def get_dynamic_shapes(self):
        dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
        image_dynamic_dim = {
            0: 1,
            1: 1,
            2: dim,
            3: 3,
            4: self.config.tile_size,
            5: self.config.tile_size,
        }
        return (image_dynamic_dim, None)
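
The model above only defines the eager module and its export metadata. As a companion to the AOTI test below, here is a minimal sketch of how the encoder could be lowered to an ExecuTorch program with the same example inputs and dynamic shapes, using the standard `torch.export` → `to_edge` → `to_executorch` flow (not part of this commit):

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
)
from executorch.exir import to_edge

model = FlamingoVisionEncoderModel()

# Export the eager encoder with a dynamic number of tiles.
ep = torch.export.export(
    model.get_eager_model(),
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)

# Lower the exported program to an ExecuTorch program and serialize it.
et_program = to_edge(ep).to_executorch()
with open("llama3_2_vision_encoder.pte", "wb") as f:
    f.write(et_program.buffer)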

examples/models/llama3_2_vision/vision_encoder/test/__init__.py

Whitespace-only changes.

New test file under examples/models/llama3_2_vision/vision_encoder/test/

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
# Only test AOTI in this file
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
)
from torch.testing import assert_close


class FlamingoVisionEncoderTest(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def test_flamingo_vision_encoder(self) -> None:
        model = FlamingoVisionEncoderModel()
        encoder = model.model
        eager_res = encoder.forward(*model.get_example_inputs())

        # AOTI
        ep = torch.export.export(
            encoder,
            model.get_example_inputs(),
            dynamic_shapes=model.get_dynamic_shapes(),
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = torch._inductor.aoti_compile_and_package(
                ep,
                model.get_example_inputs(),
                package_path=os.path.join(tmpdir, "vision_encoder.pt2"),
            )
            print(path)
            encoder_aoti = torch._inductor.aoti_load_package(path)

            y = encoder_aoti(*model.get_example_inputs())
            assert_close(y, eager_res, rtol=1e-4, atol=1e-4)

pytest.ini

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@ addopts =
     devtools/
     # examples
     examples/models/llama/tests
-    examples/models/llama3_2_vision/preprocess
+    examples/models/llama3_2_vision/preprocess/test
+    examples/models/llama3_2_vision/vision_encoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
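
With these paths registered, the new vision encoder test is collected by the repository's pytest configuration, so it can presumably be run locally with `pytest examples/models/llama3_2_vision/vision_encoder/test` from the repository root (a local torchtune installation is required, since the encoder builder is imported from torchtune).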
