
[llama-mm] Make text decoder exportable #6999


Merged · 5 commits · Nov 21, 2024
37 changes: 24 additions & 13 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -17,6 +17,7 @@
)

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.verbose = kwargs.get("verbose", False)
self.args = kwargs.get("args", None)
self.dtype = None
self.dtype = kwargs.get("dtype", torch.float16)
self.use_checkpoint = False

ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
dtype=torch.bool,
)
)
self.input_pos = torch.arange(self.max_seq_len)
self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

# Load checkpoint and params.
device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
rope_base=params["rope_theta"],
intermediate_dim=params["intermediate_dim"],
)

# Source transformation for MultiHeadAttention
self.model_ = replace_mha_with_inference_mha(self.model_)
# Save params for future use.
for param_name, param_val in params.items():
setattr(self.model_, param_name, param_val)
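
For context, replace_mha_with_inference_mha rewrites the decoder's attention modules in place so that export sees an export-friendly implementation. The general shape of such a source transformation is a recursive module swap; a minimal sketch of the pattern (the helper and its names below are illustrative, not the actual ExecuTorch implementation):

    from typing import Callable, Type

    import torch

    def swap_modules(
        root: torch.nn.Module,
        old_type: Type[torch.nn.Module],
        build_replacement: Callable[[torch.nn.Module], torch.nn.Module],
    ) -> torch.nn.Module:
        # Walk the module tree and replace each instance of `old_type`
        # with the module returned by `build_replacement(old_module)`.
        for name, child in root.named_children():
            if isinstance(child, old_type):
                setattr(root, name, build_replacement(child))
            else:
                swap_modules(child, old_type, build_replacement)
        return root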
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
self.model_.setup_caches(
batch_size=1,
dtype=self.dtype,
encoder_max_seq_len=self.encoder_max_seq_len,
decoder_max_seq_len=self.max_seq_len,
)
# number of tokens for example input
self.n_tokens = 34
[Review comment — Contributor] Why 34? Just curious.
[Reply — larryliu0820 (author), Nov 21, 2024] Just the number of tokens used by the example input. Could be something else.

self.model_.to(self.dtype)

def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
return self.model_.to(self.dtype)
else:
return self.model_.to(torch.float16)
return self.model_

def get_example_inputs(self):
return (torch.ones(1, 32, dtype=torch.long),)
return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

def get_example_kwarg_inputs(self):
# For export we must use the prefill versions of the
# causal mask and input_pos.
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
if self.use_kv_cache:
return {
"input_pos": self.input_pos[None, :32],
"mask": self.causal_mask[None, :32],
# "encoder_input": None,
# "encoder_mask": None,
"input_pos": self.input_pos[None, : self.n_tokens],
"mask": self.causal_mask[None, : self.n_tokens],
"encoder_input": torch.randn(
1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
),
"encoder_mask": torch.ones(
[1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
),
}
else:
return None

def get_dynamic_shapes(self):
batch_size = 1
dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
if self.use_kv_cache:
dynamic_shapes = {
"tokens": {0: batch_size, 1: dim_seq_len},
# "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
# "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
"encoder_input": None,
"encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
"mask": {0: batch_size, 1: dim_seq_len, 2: None},
"input_pos": {0: batch_size, 1: dim_seq_len},
}
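
For reference, each entry above maps a forward argument to a per-dimension spec: a torch.export.Dim makes that dimension symbolic within [min, max], while None (or an int) pins it as static. A minimal, self-contained sketch of the same mechanism on a toy module (the module, names, and sizes here are illustrative):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            return tokens + 1

    token_dim = torch.export.Dim("token_dim", min=1, max=128)
    ep = torch.export.export(
        Toy(),
        (torch.ones(1, 32, dtype=torch.int64),),
        # Dim 0 stays static (batch of 1); dim 1 may vary from 1 to 128.
        dynamic_shapes={"tokens": {1: token_dim}},
    )
    # The exported program now accepts any sequence length in range.
    ep.module()(torch.ones(1, 64, dtype=torch.int64))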
Empty file.
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
[Review comment — Contributor] Can rename to test_aoti.py or something.
[Reply — Author] Was hoping to host all tests related to the text decoder here.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for this model are covered by test_models.sh.
# Only test AOTI in this file.
import json
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import (
Llama3_2Decoder,
)
from torch.testing import assert_close

params = {
[Review note — Author] Changed: reading JSON from a file is not very convenient in a unit test, so the params are defined inline.
"dim": 2048,
"ffn_dim_multiplier": 1.3,
"fusion_interval": 2,
"intermediate_dim": 14336,
"multiple_of": 1024,
"n_heads": 32,
"n_kv_heads": 8,
"n_layers": 2,
"n_special_tokens": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
"use_scaled_rope": True,
"vision_chunk_size": 560,
"vision_max_num_chunks": 4,
"vocab_size": 21008,
"vision_num_cross_attention_layers": 1,
}


class TextDecoderTest(unittest.TestCase):
def setUp(self) -> None:
super().setUp()

def _set_requires_grad_false(self, model: torch.nn.Module) -> None:
for param in model.parameters():
param.requires_grad = False
for child in model.children():
self._set_requires_grad_false(child)

def test_llama3_2_text_decoder_aoti(self) -> None:
with tempfile.NamedTemporaryFile(mode="w") as param_file:
json.dump(params, param_file, indent=2)
param_file.flush()
model = Llama3_2Decoder(
encoder_max_seq_len=6404,
generate_full_logits=True,
enable_dynamic_shape=True,
use_kv_cache=True,
params=param_file.name,
dtype=torch.float32,
)
encoder = model.get_eager_model().eval()
self._set_requires_grad_false(encoder)

# AOTI
with torch.no_grad(), torch.inference_mode():
ep = torch.export.export(
encoder,
model.get_example_inputs(),
kwargs=model.get_example_kwarg_inputs(),
dynamic_shapes=model.get_dynamic_shapes(),
)
with tempfile.TemporaryDirectory() as tmpdir:
path = torch._inductor.aoti_compile_and_package(
ep,
model.get_example_inputs(),
kwargs=model.get_example_kwarg_inputs(),
package_path=os.path.join(tmpdir, "text_decoder.pt2"),
)
encoder_aoti = torch._inductor.aoti_load_package(path)

y = encoder_aoti(
*model.get_example_inputs(), **model.get_example_kwarg_inputs()
)

eager_res = encoder.forward(
*model.get_example_inputs(), **model.get_example_kwarg_inputs()
)
assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
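
As a usage note, the .pt2 artifact written by aoti_compile_and_package is self-contained and can be reloaded later, including from a separate process, via aoti_load_package; the loaded runner is called like the original forward. A minimal sketch on a toy module, mirroring the call pattern in the test above (the exact aoti_compile_and_package signature has shifted across PyTorch releases, so treat this as indicative):

    import os
    import tempfile

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    example = (torch.randn(4),)
    with torch.no_grad():
        ep = torch.export.export(Toy(), example)
        with tempfile.TemporaryDirectory() as tmpdir:
            pkg = torch._inductor.aoti_compile_and_package(
                ep, example, package_path=os.path.join(tmpdir, "toy.pt2")
            )
            runner = torch._inductor.aoti_load_package(pkg)
            # Same outputs as eager, compiled ahead of time.
            print(runner(torch.randn(4)))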
5 changes: 3 additions & 2 deletions pytest.ini
@@ -18,6 +18,7 @@ addopts =
examples/models/llama/tests
examples/models/llama3_2_vision/preprocess
examples/models/llama3_2_vision/vision_encoder/test
examples/models/llama3_2_vision/text_decoder/test
# examples/models/llava/test TODO: enable this
# exir
exir/_serialize/test
@@ -43,8 +44,8 @@ addopts =
extension/pybindings/test
# Runtime
runtime
# test
test/end2end/test_end2end.py
# test TODO: fix these tests
# test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
--ignore=backends/xnnpack/test/models/llama2_et_example.py
# T200992559: Add torchao to ET as core dependency