
Commit e283967

Martin Yuan authored and facebook-github-bot committed
Add Llava model to examples (#2576)
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #2576

Pull Request resolved: #2576
Reviewed By: cccclai
Differential Revision: D55268288
Pulled By: iseeyuan
fbshipit-source-id: 4574fdfec46594161d0e7f34d2c9187bb1edb1d5
1 parent 250f681 commit e283967

File tree

9 files changed (+129, -2 lines)


.ci/scripts/gather_test_models.py

Lines changed: 15 additions & 2 deletions
@@ -13,7 +13,6 @@
 from examples.models import MODEL_NAME_TO_MODEL
 from examples.xnnpack import MODEL_NAME_TO_OPTIONS

-
 DEFAULT_RUNNERS = {
     "linux": "linux.2xlarge",
     "macos": "macos-m1-stable",
@@ -24,6 +23,7 @@
         "w2l": "linux.12xlarge",
         "ic4": "linux.12xlarge",
         "resnet50": "linux.12xlarge",
+        "llava_encoder": "linux.4xlarge",
         # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
         "dl3": "linux.12xlarge",
         "emformer_join": "linux.12xlarge",
@@ -83,7 +83,17 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     We put higher priority and fast models to pull request and rest to push.
     """
     if event == "pull_request":
-        return model in ["add", "ic3", "mv2", "mv3", "resnet18", "vit"]
+        return model in ["add", "ic3", "mv2", "mv3", "resnet18", "vit", "llava_encoder"]
+    return True
+
+
+def model_should_run_on_target_os(model: str, target_os: str) -> bool:
+    """
+    A helper function to decide whether a model should be tested on a target os (linux/macos).
+    For example, a big model can be disabled in macos due to the limited macos resources.
+    """
+    if target_os == "macos":
+        return model not in ["llava_encoder"]
     return True

@@ -119,6 +129,9 @@ def export_models_for_ci() -> dict[str, dict]:
         if not model_should_run_on_event(name, event):
             continue

+        if not model_should_run_on_target_os(name, target_os):
+            continue
+
         if backend == "xnnpack":
             if name not in MODEL_NAME_TO_OPTIONS:
                 continue
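
Not part of the diff: a tiny illustration of how the two helpers above combine for the new model — `llava_encoder` runs on pull requests but is skipped on macOS runners (assuming `.ci/scripts` is on the import path).

```python
from gather_test_models import model_should_run_on_event, model_should_run_on_target_os

assert model_should_run_on_event("llava_encoder", "pull_request")
assert not model_should_run_on_target_os("llava_encoder", "macos")
assert model_should_run_on_target_os("llava_encoder", "linux")
```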

.ci/scripts/test.sh

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@ test_model() {
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
   fi
+  if [[ "${MODEL_NAME}" == "llava_encoder" ]]; then
+    # Install requirements for llava
+    bash examples/models/llava_encoder/install_requirements.sh
+  fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}"
   run_portable_executor_runner

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -62,3 +62,6 @@
 [submodule "kernels/optimized/third-party/eigen"]
     path = kernels/optimized/third-party/eigen
     url = https://gitlab.com/libeigen/eigen.git
+[submodule "examples/third-party/LLaVA"]
+    path = examples/third-party/LLaVA
+    url = https://github.com/haotian-liu/LLaVA.git

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
     "ic4": ("inception_v4", "InceptionV4Model"),
     "resnet18": ("resnet", "ResNet18Model"),
     "resnet50": ("resnet", "ResNet50Model"),
+    "llava_encoder": ("llava_encoder", "LlavaModel"),
 }

 __all__ = [
examples/models/llava_encoder/README.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
## Summary
In this example, we initiate the process of running multimodal models through ExecuTorch.
- Demonstrate how to export the image encoder of the [LLaVA](https://github.com/haotian-liu/LLaVA) multimodal model.
- Provide TODO steps on how to use the exported .pte file together with the existing [exported Llama2 model](https://github.com/pytorch/executorch/tree/main/examples/models/llama2) to build the multimodal pipeline.

## Instructions
Note that this folder does not host the pretrained LLaVA model.
- To have LLaVA available, follow the [Install instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install) in the LLaVA GitHub repo. Follow the license in that repo when using LLaVA.
- Since the PyTorch version pulled in by LLaVA may be older than the one ExecuTorch needs, `cd executorch` and run `./install_requirements.sh` to restore ExecuTorch's dependencies.
- If there is a numpy compatibility issue, run `pip install bitsandbytes -I`.
- Alternatively, run `examples/models/llava_encoder/install_requirements.sh`, which performs the steps above.
- Run `python3 -m examples.portable.scripts.export --model_name="llava_encoder"`. The `llava_encoder.pte` file will be generated (see the sketch below for roughly what this step does).
- Run `./cmake-out/executor_runner --model_path ./llava_encoder.pte` to verify the exported model with the ExecuTorch runtime using portable kernels. Note that the portable kernels are not performance optimized; refer to other examples, such as those in the llama2 folder, for optimizations.

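Not part of this commit: a minimal Python sketch of roughly what the export step above does under the hood. The `to_edge`/`to_executorch` calls are the generic ExecuTorch export APIs; the exact options used by `examples.portable.scripts.export` may differ.

```python
import torch

from examples.models.llava_encoder import LlavaModel
from executorch.exir import to_edge

# Build the eager image-encoder wrapper and grab a preprocessed example image.
wrapper = LlavaModel()
model = wrapper.get_eager_model().eval()
example_inputs = wrapper.get_example_inputs()

# Capture the graph, lower it to the Edge dialect, and serialize a .pte file.
with torch.no_grad():
    exported_program = torch.export.export(model, example_inputs)
edge_program = to_edge(exported_program)
executorch_program = edge_program.to_executorch()

with open("llava_encoder.pte", "wb") as f:
    f.write(executorch_program.buffer)
```
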
## TODO
- Write the pipeline in C++.
  - Have image and text prompts as inputs.
  - Call image processing functions to preprocess the image tensor.
  - Load the llava_encoder.pte model and run it on the image tensor.
  - The output of the encoder can be combined with the prompt as inputs to the llama model. Call functions in llama_runner.cpp to run the llama model and get outputs. The ExecuTorch end-to-end flow for the llama model is located at `examples/models/llama2`. (A rough Python sketch of this flow follows the list.)
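
Not part of this commit: a rough Python sketch of the intended pipeline order, useful as a reference while the C++ version is being written. It assumes the ExecuTorch pybinding `_load_for_executorch` for loading the .pte file; the final implementation is planned in C++ alongside llama_runner.cpp.

```python
from examples.models.llava_encoder import LlavaModel
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Reuse the eager wrapper only for its image preprocessing helper.
wrapper = LlavaModel()
(images_tensor,) = wrapper.get_example_inputs()

# Run the exported image encoder with the ExecuTorch runtime.
encoder = _load_for_executorch("llava_encoder.pte")
image_embeds = encoder.forward((images_tensor,))[0]

# Next (not shown): splice image_embeds into the token embeddings of the text
# prompt and feed the combined sequence to the exported llama2 model
# (examples/models/llama2), as llama_runner.cpp does for text-only prompts.
```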
examples/models/llava_encoder/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import LlavaModel

__all__ = [
    "LlavaModel",
]
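
Not part of this commit: a small illustration of how a registry entry such as `"llava_encoder": ("llava_encoder", "LlavaModel")` is typically resolved into a model instance. The package prefix passed to `importlib` is an assumption; the real lookup lives in the examples model factory.

```python
import importlib

# (module name, class name) as registered in MODEL_NAME_TO_MODEL
module_name, class_name = ("llava_encoder", "LlavaModel")

# Resolve the module and class, then instantiate the EagerModelBase wrapper.
module = importlib.import_module(f"examples.models.{module_name}")
model_wrapper = getattr(module, class_name)()

eager_model = model_wrapper.get_eager_model()
example_inputs = model_wrapper.get_example_inputs()
```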
examples/models/llava_encoder/install_requirements.sh

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# install llava from the submodule
pip install --force-reinstall -e examples/third-party/LLaVA

# not included in the pip install package, but needed in llava
pip install protobuf

# bitsandbytes depends on numpy 1.x, which is not compatible with numpy 2.x.
# Reinstall bitsandbytes to make it compatible.
pip install bitsandbytes -I

# The deps of llava can have different versions than deps of ExecuTorch.
# For example, torch version required from llava is older than ExecuTorch.
# To make both work, recover ExecuTorch's original dependencies by rerunning
# the install_requirements.sh.
./install_requirements.sh
examples/models/llava_encoder/model.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from examples.models.model_base import EagerModelBase
from llava.eval.run_llava import load_images, process_images
from llava.mm_utils import get_model_name_from_path

from llava.model.builder import load_pretrained_model
from torch import nn


class EncoderModel(nn.Module):
    def __init__(self, llava_model):
        super().__init__()
        self.model_ = llava_model

    def forward(self, images_tensor):
        features = self.model_.get_model().get_vision_tower()(images_tensor)
        features = self.model_.get_model().mm_projector(features)
        return features


class LlavaModel(EagerModelBase):
    def __init__(self):
        model_path = "liuhaotian/llava-v1.5-7b"
        tokenizer, self.model_, self.image_processor_, context_len = (
            load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=get_model_name_from_path(model_path),
            )
        )
        self.device = "cpu"
        self.dtype = torch.float32
        self.model_.to(device=self.device, dtype=self.dtype)

    def get_eager_model(self):
        model = EncoderModel(self.model_)
        return model

    def get_example_inputs(self):
        image_file = "https://llava-vl.github.io/static/images/view.jpg"
        images = load_images([image_file])
        images_tensor = process_images(
            images, self.image_processor_, self.model_.config
        ).to(self.model_.device)
        return (images_tensor,)
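
Not part of this commit: a short eager-mode usage sketch of the wrapper above, mirroring the flow the export script follows (instantiating `LlavaModel` downloads the liuhaotian/llava-v1.5-7b checkpoint).

```python
import torch

from examples.models.llava_encoder import LlavaModel

wrapper = LlavaModel()                        # loads the checkpoint on CPU in fp32
encoder = wrapper.get_eager_model().eval()    # EncoderModel: vision tower + mm_projector
(images_tensor,) = wrapper.get_example_inputs()

with torch.no_grad():
    features = encoder(images_tensor)         # image features projected into the LLM embedding space

print(features.shape)
```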

examples/third-party/LLaVA

Submodule LLaVA added at 7440ec9
