
Commit 0333390

larryliu0820 authored and facebook-github-bot committed
Add export_llava.py (#4295)
Summary: Pull Request resolved: #4295

As titled. Pending CI job

Reviewed By: helunwencser

Differential Revision: D59901269

fbshipit-source-id: 0f32357830a677736ac3123526653bff70c8c7af
1 parent 1933dae commit 0333390

File tree

2 files changed: +232 -0 lines


.github/workflows/pull.yml

Lines changed: 31 additions & 0 deletions
@@ -187,6 +187,37 @@ jobs:
         # Test selective build
         PYTHON_EXECUTABLE=python bash examples/selective_build/test_selective_build.sh "${BUILD_TOOL}"

+  test-export-llava-linux:
+    name: test-export-llava-linux
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    strategy:
+      fail-fast: false
+    with:
+      runner: linux.12xlarge
+      docker-image: executorch-ubuntu-22.04-clang12
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
+
+        # install Llava requirements
+        bash examples/models/llama2/install_requirements.sh
+        bash examples/models/llava/install_requirements.sh
+
+        # run export_llava.py
+        python examples/models/llava/export_llava.py
+
+        # verify file exists
+        if [ ! -f "llava_combined_xnnpack.pte" ]; then
+          echo "llava_combined_xnnpack.pte not found!"
+          exit 1
+        fi
+
   test-quantized-aot-lib-linux:
     name: test-quantized-aot-lib-linux
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main

examples/models/llava/export_llava.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
    # XnnpackFloatingPointPartitioner,
)
from executorch.examples.models.llama2.export_llama_lib import (
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama2.source_transformation.quantize import (
    get_quant_weight_transform,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from model import LlavaModel
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import Dim
from torch.nn.attention import SDPBackend


class LlavaEdgeManager(LLMEdgeManager):
    def capture_pre_autograd_graph(self) -> "LlavaEdgeManager":
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            self.export_program = torch.export.export(
                self.model,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                strict=False,
            )
            self.pre_autograd_graph_module = self.export_program.module()
        return self


def export_text_model(llava, embeddings, dynamic_shapes):
    class LlavaTextModel(torch.nn.Module):
        """Takes the current input position and precomputed embeddings and runs the Llava text model (decoder) on them."""

        def __init__(self, llava):
            super().__init__()
            self.text_model = llava.text_model

        def forward(self, input_pos, embeddings):
            return self.text_model(None, input_pos, embeddings)

    llava_text_model = LlavaTextModel(llava)

    text_model_em = LLMEdgeManager(
        model=llava_text_model,
        modelname="llava_text_model",
        max_seq_len=llava.text_model_args.max_seq_len,
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
        dynamic_shapes=dynamic_shapes,
    )

    dtype_override = DType.fp32
    # Reuse the llama2 export CLI flags: XNNPACK (-X), 8da4w weight quantization
    # with group size 128, and 4-bit group-size-32 embedding quantization.
    parser = build_args_parser()
    args = parser.parse_args(
        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
    )
    quant_transform = get_quant_weight_transform(args, dtype_override, False)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    manager = (
        text_model_em.set_output_dir("./")
        .to_dtype(dtype_override)
        .source_transform([replace_sdpa_with_custom_op, quant_transform])
        .capture_pre_autograd_graph()
        .pt2e_quantize(quantizers)
    )

    with torch.no_grad():
        text_model_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager._get_dynamic_shape(),
        )
    return text_model_ep


def export_image_encoder(llava, resized, dynamic_shapes):
    class LlavaImageEncoder(torch.nn.Module):
        """Takes images and encodes them into embeddings. The result is sent to the text model LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.llava = llava

        def forward(self, images):
            return self.llava.image_embedding(images)

    llava_image_encode = LlavaImageEncoder(llava)

    # quantizer: dynamic, per-channel symmetric quantization config for linear ops
    linear_quantizer = XNNPACKQuantizer()
    operator_config_dynamic = get_symmetric_quantization_config(
        is_per_channel=True, is_dynamic=True
    )
    linear_quantizer.set_global(operator_config_dynamic)

    manager = LlavaEdgeManager(
        model=llava_image_encode,
        modelname="llava_image_encoder",
        max_seq_len=llava.text_model_args.max_seq_len,  # This may not be right
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(resized,),
        dynamic_shapes=dynamic_shapes,
    ).capture_pre_autograd_graph()

    # export to an ExportedProgram; lowering to ExecuTorch happens in main()
    with torch.no_grad():
        image_encoder_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager.dynamic_shapes,
        )
    return image_encoder_ep


def export_token_embedding(llava, prompt):
    embed = torch.nn.Embedding(
        llava.model_.config.vocab_size,
        llava.model_.config.hidden_size,
        llava.model_.config.pad_token_id,
    )
    embed.load_state_dict(
        llava.model_.get_model().embed_tokens.state_dict(), strict=True, assign=True
    )
    embed = embed.to(torch.float32)
    token_dim_1 = Dim("token_dim_1", min=2, max=3518)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
        token_embedding_ep = torch.export.export(
            embed, (prompt,), dynamic_shapes=dynamic_shapes
        )
    return token_embedding_ep


def main():
    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()

    prompt_before_image, resized, prompt_after_image = (
        llava_model.get_inputs_for_prefill()
    )

    image_encoder_ep = export_image_encoder(
        llava, resized, llava_model._get_image_dynamic_shapes()
    )

    embeddings = llava.prefill_embedding(
        prompt_before_image, resized, prompt_after_image
    )

    text_model_ep = export_text_model(
        llava, embeddings, llava_model._get_prompt_dynamic_shapes()
    )

    token_embedding_ep = export_token_embedding(llava, prompt_before_image)

    # Combine the three exported programs into a single Edge program with one
    # entry point per component.
    edge_ep = to_edge(
        {
            "image_encoder": image_encoder_ep,
            "token_embedding": token_embedding_ep,
            "text_model": text_model_ep,
        },
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    # Delegate the quantized text model to XNNPACK and serialize everything
    # into one ExecuTorch program.
    executorch_program = edge_ep.to_backend(
        {
            # TODO: Fix Xnnpack partitioner issue on image encoder.
            # "image_encoder": XnnpackFloatingPointPartitioner(),
            "text_model": XnnpackDynamicallyQuantizedPartitioner(),
        }
    ).to_executorch()

    with open("llava_combined_xnnpack.pte", "wb") as f:
        executorch_program.write_to_file(f)


if __name__ == "__main__":
    main()
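
After the export finishes, a quick sanity check beyond the CI's file-existence test is to load llava_combined_xnnpack.pte with the ExecuTorch Python bindings and call one of the exported methods. The sketch below is not part of this commit; it assumes the pybindings (executorch.extension.pybindings.portable_lib) are built in your environment, that the method names match the keys passed to to_edge() above, and that the binding API (which can vary across ExecuTorch versions) exposes run_method.

# Sanity-check sketch (not part of this commit): load the combined .pte and run
# the token_embedding entry point on dummy token ids. Assumes the ExecuTorch
# pybindings are available; exact API details may differ across versions.
import torch

from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("llava_combined_xnnpack.pte")

# Method names correspond to the keys passed to to_edge() in export_llava.py.
dummy_prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.int64)  # dim 1 within token_dim_1 bounds
outputs = module.run_method("token_embedding", (dummy_prompt,))
print("token_embedding output shape:", outputs[0].shape)

Running image_encoder or text_model the same way requires inputs that respect the dynamic-shape bounds used during export, so token_embedding is the cheapest smoke test.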
