
Commit 1933dae

larryliu0820 authored and facebook-github-bot committed
Add Llava model definition (#4259)
Summary: Pull Request resolved: #4259
Reviewed By: helunwencser
Differential Revision: D59759978
fbshipit-source-id: 8ff8a5b24481b28e0814b45f60b4b0fdbfd47e4e
1 parent c757499 commit 1933dae

7 files changed (+441, -29 lines)

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     "w2l": "linux.12xlarge",
     "ic4": "linux.12xlarge",
     "resnet50": "linux.12xlarge",
-    "llava": "linux.4xlarge",
+    "llava": "linux.12xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

.ci/scripts/test.sh

Lines changed: 3 additions & 1 deletion
@@ -67,12 +67,14 @@ test_model() {
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
   fi
+  STRICT="--strict"
   if [[ "${MODEL_NAME}" == "llava" ]]; then
     # Install requirements for llava
     bash examples/models/llava/install_requirements.sh
+    STRICT="--no-strict"
   fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
-  "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}"
+  "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"
   run_portable_executor_runner
 }
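
By default torch.export traces in strict mode (TorchDynamo-based); the llava export needs the more permissive non-strict mode, hence the new STRICT toggle above. A minimal sketch of how an export entry point can wire such a flag through to torch.export.export; the argparse wiring and dummy module are illustrative, not the actual examples.portable.scripts.export source:

# Sketch: map the --strict / --no-strict CLI flags onto torch.export's
# `strict` kwarg. The module being exported here is a stand-in.
import argparse

import torch


class Dummy(torch.nn.Module):
    def forward(self, x):
        return x + 1


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--strict", dest="strict", action="store_true", default=True)
    parser.add_argument("--no-strict", dest="strict", action="store_false")
    args = parser.parse_args()

    # strict=False falls back to non-strict tracing, which tolerates
    # constructs that TorchDynamo-based strict tracing rejects.
    ep = torch.export.export(Dummy(), (torch.randn(2),), strict=args.strict)
    print(ep)


if __name__ == "__main__":
    main()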

examples/models/llava/install_requirements.sh

Lines changed: 7 additions & 1 deletion
@@ -17,8 +17,14 @@ pip install protobuf
 # Reinstall bitsandbytes to make it compatible.
 pip install bitsandbytes -I
 
+# numpy needs to be pin to 1.24. 1.26.4 will error out
+pip install numpy==1.24
+
+# Newer transformer will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
+pip install transformers==4.37.2
+
 # The deps of llava can have different versions than deps of ExecuTorch.
 # For example, torch version required from llava is older than ExecuTorch.
 # To make both work, recover ExecuTorch's original dependencies by rerunning
 # the install_requirements.sh.
-bash -x ./install_requirements.sh
+bash -x ./install_requirements.sh --pybind xnnpack
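
Because llava's own requirements overwrite some ExecuTorch dependencies and the final install_requirements.sh rerun restores them, a quick post-install check helps confirm the two pins above survived. A minimal sanity-check sketch (not part of the commit):

# Run after install_requirements.sh finishes; asserts the pins above held.
import numpy
import transformers

assert numpy.__version__.startswith("1.24"), f"numpy is {numpy.__version__}"
assert transformers.__version__ == "4.37.2", (
    f"transformers is {transformers.__version__}"
)
print("numpy and transformers pins look good")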

examples/models/llava/main.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+
+from model import LlavaModel
+
+
+def main():
+
+    llava_model = LlavaModel()
+    llava = llava_model.get_eager_model()
+
+    prompt_before_image, resized, prompt_after_image = llava_model.get_example_inputs()
+    logging.info(f"Prompt: {llava_model.prompt}")
+    preprocessed = llava.image_preprocess(resized)
+    with torch.inference_mode():
+        output_ids = llava_model.model.generate(
+            llava_model.input_ids,
+            images=preprocessed,
+            image_sizes=[preprocessed.size],
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=10,
+            use_cache=True,
+        )
+
+    outputs = llava_model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
+        0
+    ].strip()
+    logging.info(f"Reference output: {outputs}")
+
+    # comparing with llava result
+    # prefill_logits = llava.prefill(prompt_before_image, resized, prompt_after_image)
+    # prefill_logits_ref = llava.prefill_ref(*inputs)[0]
+    # print(f"Prefill logits all close? {torch.allclose(prefill_logits, prefill_logits_ref, atol=1e-3)}")
+
+    # prefill_logits = llava.prefill(*inputs)
+    # context_len = prefill_logits.shape[1]
+    # print(prefill_logits)
+    # # first token
+    # new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+    # # print(tokenizer.decode(new_tokens))
+    # for i in range(llava_model.args.max_new_tokens):
+    #     print(i, llava_model.tokenizer.decode(new_tokens[i]))
+    #     logits = llava.forward(
+    #         torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
+    #     )
+    #     new_tokens.append(torch.argmax(logits[-1, :]))
+    prefill_logits = llava.prefill(prompt_before_image, resized, prompt_after_image)
+    context_len = prefill_logits.shape[1]
+    logging.info(prefill_logits)
+    new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+    i = 0
+    logging.info(i, llava_model.tokenizer.decode(new_tokens[i]))
+    logits = llava.step(torch.tensor([new_tokens[i]]), torch.tensor([context_len + i]))
+    logging.info(logits)
+
+
+if __name__ == "__main__":
+    main()
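
main() above finishes with a single prefill plus one decode step as a sanity check. Extending that to a full greedy decode follows the commented-out loop in the same file; a minimal sketch reusing the prefill/step interface (the helper itself is illustrative, not part of the commit):

# Sketch: greedy decoding on top of the prefill/step interface used in
# main.py. Assumes `llava` and `llava_model` are built exactly as in main().
import torch


def greedy_decode(
    llava, llava_model, prompt_before_image, resized, prompt_after_image,
    max_new_tokens=10,
):
    # Prefill consumes the full multimodal prompt; the logits at the last
    # position predict the first new token.
    prefill_logits = llava.prefill(prompt_before_image, resized, prompt_after_image)
    context_len = prefill_logits.shape[1]
    new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]

    for i in range(max_new_tokens):
        # step() takes one token id plus its KV-cache position and returns
        # logits for the next position.
        logits = llava.step(
            torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
        )
        new_tokens.append(torch.argmax(logits[-1, :]).item())

    return llava_model.tokenizer.decode(new_tokens)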
