Skip to content

Commit 801276e

Browse files
committed
[llava] Use Hugging Face LLaVA instead of depending on third-party/LLaVA
Currently we depend on third-party/LLaVA for the LLaVA model definition. This is hard to use because we have to pull LLaVA in as a git submodule and install it from there, and it breaks a lot of dependency assumptions. This PR removes `third-party/LLaVA` in favor of the Hugging Face LLaVA model definition. ghstack-source-id: 3c11fa5 Pull Request resolved: #4687
1 parent 2117c1a commit 801276e

File tree

7 files changed

+188
-201
lines changed

7 files changed

+188
-201
lines changed

.github/workflows/pull.yml

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -209,23 +209,6 @@ jobs:
209209
bash examples/models/llama2/install_requirements.sh
210210
bash examples/models/llava/install_requirements.sh
211211
212-
# run export_llava.sh
213-
python examples/models/llava/export_llava.py --use-sdpa-with-kv-cache --pte-name llava_custom_sdpa.pte
214-
215-
# verify file exists
216-
if [ ! -f "llava_custom_sdpa.pte" ]; then
217-
echo "llava_custom_sdpa.pte not found!"
218-
exit 1
219-
fi
220-
221-
python examples/models/llava/export_llava.py --no-use-sdpa-with-kv-cache --pte-name llava.pte
222-
223-
# verify file exists
224-
if [ ! -f "llava.pte" ]; then
225-
echo "llava.pte not found!"
226-
exit 1
227-
fi
228-
229212
# run python unittest
230213
python -m unittest examples.models.llava.test.test_llava
231214

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
[submodule "backends/xnnpack/third-party/pthreadpool"]
2929
path = backends/xnnpack/third-party/pthreadpool
3030
url = https://github.com/Maratyszcza/pthreadpool.git
31-
[submodule "examples/third-party/LLaVA"]
32-
path = examples/third-party/LLaVA
33-
url = https://github.com/haotian-liu/LLaVA.git
3431
[submodule "examples/third-party/fbjni"]
3532
path = examples/third-party/fbjni
3633
url = https://github.com/facebookincubator/fbjni.git

examples/models/llava/export_llava.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def forward(self, input_pos, embeddings):
8585
["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
8686
)
8787
quant_transform = get_quant_weight_transform(args, dtype_override, False)
88-
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
88+
_, quantizers, _ = get_quantizer_and_quant_params(args)
8989
source_transforms = []
9090
if llava.use_sdpa_with_kv_cache_op:
9191
source_transforms.append(replace_sdpa_with_custom_op)
@@ -149,15 +149,7 @@ def forward(self, images):
149149

150150

151151
def export_token_embedding(llava, prompt):
152-
embed = torch.nn.Embedding(
153-
llava.model_.config.vocab_size,
154-
llava.model_.config.hidden_size,
155-
llava.model_.config.pad_token_id,
156-
)
157-
embed.load_state_dict(
158-
llava.model_.get_model().embed_tokens.state_dict(), strict=True, assign=True
159-
)
160-
embed = embed.to(torch.float32)
152+
embed = llava.embed_tokens
161153
token_dim_1 = Dim("token_dim_1", min=2, max=3518)
162154
dynamic_shapes = [{1: token_dim_1}]
163155
with torch.no_grad():
@@ -167,24 +159,7 @@ def export_token_embedding(llava, prompt):
167159
return token_embedding_ep
168160

169161

170-
def main():
171-
parser = ArgumentParser()
172-
parser.add_argument(
173-
"--use-sdpa-with-kv-cache",
174-
default=True,
175-
action=BooleanOptionalAction,
176-
help="Use sdpa_with_kv_cache custom op in LLava text model.",
177-
)
178-
parser.add_argument(
179-
"--pte-name",
180-
default="llava_combined_xnnpack.pte",
181-
help="Name of the exported ExecuTorch program.",
182-
)
183-
args = parser.parse_args()
184-
logging.info(
185-
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}"
186-
)
187-
llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
162+
def export_all(llava_model: LlavaModel):
188163
llava = llava_model.get_eager_model()
189164

190165
(
@@ -226,6 +201,29 @@ def main():
226201
)
227202

228203
executorch_program = lowered_and_edge.to_executorch()
204+
return executorch_program
205+
206+
207+
def main():
208+
parser = ArgumentParser()
209+
parser.add_argument(
210+
"--use-sdpa-with-kv-cache",
211+
default=True,
212+
action=BooleanOptionalAction,
213+
help="Use sdpa_with_kv_cache custom op in LLava text model.",
214+
)
215+
parser.add_argument(
216+
"--pte-name",
217+
default="llava_combined_xnnpack.pte",
218+
help="Name of the exported ExecuTorch program.",
219+
)
220+
args = parser.parse_args()
221+
logging.info(
222+
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}"
223+
)
224+
llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
225+
226+
executorch_program = export_all(llava_model)
229227

230228
with open(args.pte_name, "wb") as f:
231229
executorch_program.write_to_file(f)

examples/models/llava/install_requirements.sh

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
set -x
9-
OS=$(uname)
109

11-
# install llava from the submodule. We can't do pip install llava because it is packaged incorrectly.
12-
if [[ $OS != "Darwin" ]];
13-
then
14-
#This doesn't work for macos, on python 3.12, because torch 2.1.2 is missing.
15-
pip install --force-reinstall -e examples/third-party/LLaVA
16-
else
17-
# manually install dependencies
18-
pip install tokenizers==0.15.1 sentencepiece==0.1.99 \
19-
shortuuid accelerate==0.21.0 peft \
20-
pydantic markdown2[all] scikit-learn==1.2.2 \
21-
requests httpx==0.24.0 uvicorn fastapi \
22-
einops==0.6.1 einops-exts==0.0.4 timm==0.6.13
23-
24-
pip install --force-reinstall -e examples/third-party/LLaVA --no-deps
25-
fi
26-
27-
# not included in the pip install package, but needed in llava
28-
pip install protobuf
29-
30-
# bitsandbytes depends on numpy 1.x, which is not compatible with numpy 2.x.
31-
# Reinstall bitsandbytes to make it compatible.
32-
pip install bitsandbytes -I
33-
34-
# The deps of llava can have different versions than deps of ExecuTorch.
35-
# For example, torch version required from llava is older than ExecuTorch.
36-
# To make both work, recover ExecuTorch's original dependencies by rerunning
37-
# the install_requirements.sh. Notice this won't install executorch.
38-
bash -x ./install_requirements.sh --pybind xnnpack
39-
40-
# Newer transformer (4.38) will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
41-
pip install timm==0.6.13
42-
pip install transformers==4.37.2
10+
pip install transformers
4311

4412
pip list

0 commit comments

Comments (0)