
Runner changes for TorchTune Llama3.2 vision text decoder #6610


Merged
merged 46 commits into main from jz/native-runner-tt on Nov 14, 2024

Conversation

jackzhxng
Contributor

@jackzhxng jackzhxng commented Nov 1, 2024

Summary

Changes to the eager (Python) and native (ExecuTorch) runners to run TorchTune's llama3_2_vision text decoder without a KV cache (KV cache support is in progress). This should extend to the regular TorchTune llama3_2 model as well; support will be added in follow-up PRs.

The native runner depends on #6670 landing first.

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- **YOU ARE HERE ~>** [Runner changes for TorchTune Llama3.2 vision text decoder](#6610)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Test plan

Download the model from torchtune: tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct.

Run eager:

python -m examples.models.llama.runner.eager --model llama3_2_vision --checkpoint /tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth  --params examples/models/llama3_2_vision/text_decoder/params/demo_config.json --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' --output_name="llama3_2_vision.pte" --tokenizer /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model -d fp32 --verbose --prompt "What is the capital of USA?" --max_seq_length 32

Run ExecuTorch with the portable lib (doesn't work until #6670 lands; can be tested for now by applying pytorch/pytorch#137662 to your PyTorch installation):

# Export model to executorch.
python -m examples.models.llama.export_llama --model llama3_2_vision --checkpoint /tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth  --params examples/models/llama3_2_vision/text_decoder/params/demo_config.json -d fp32 --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' --output_name="llama3_2_vision.pte"

# Run using native runner.
python -m examples.models.llama.runner.native --model llama3_2_vision --pte llama3_2_vision.pte  --tokenizer /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model --prompt "How many calories are in bread?" --params examples/models/llama3_2_vision/text_decoder/params/demo_config.json --max_len 64
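Since the decoder runs without a KV cache here, every generation step re-runs the model over the entire token sequence so far. A minimal sketch of that loop (hypothetical `decoder` callable and function names, not the actual runner API):

```python
# Hedged sketch (not the actual runner code): greedy decoding without a
# KV cache re-runs the decoder over the full sequence at every step.
# `decoder` is a hypothetical callable mapping a token list to
# per-position logits.
def generate(decoder, prompt_tokens, eos_ids, max_len):
    tokens = list(prompt_tokens)
    while len(tokens) < max_len:
        logits = decoder(tokens)   # one logits row per input position
        last = logits[-1]          # only the final row predicts the next token
        next_tok = max(range(len(last)), key=last.__getitem__)  # greedy argmax
        tokens.append(next_tok)
        if next_tok in eos_ids:    # cf. get_eos_ids in the metadata above
            break
    return tokens
```

A KV cache would avoid the quadratic cost of this loop by reusing the attention keys and values computed at earlier steps, which is what the in-progress follow-up adds.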

facebook-github-bot pushed a commit that referenced this pull request Nov 11, 2024
Summary:
Specify model to export in the CLI.


Test Plan:
Exported the stories 110M model.
```
python -m examples.models.llama.export_llama -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- [Llama2 model cleanup](#5859)
- **YOU ARE HERE ~>** [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Runner changes for TorchTune Llama3.2 vision text decoder](#6610)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Differential Revision: D65612837

Pulled By: dvorjackz
facebook-github-bot pushed a commit that referenced this pull request Nov 12, 2024
Reviewed By: helunwencser

Differential Revision: D65612837

Pulled By: dvorjackz
facebook-github-bot pushed a commit that referenced this pull request Nov 12, 2024
facebook-github-bot pushed a commit that referenced this pull request Nov 13, 2024
@@ -89,7 +101,6 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-kv",
"--kv_cache",
Contributor


I think we'd want the default to still be True?

Contributor Author


Oh yeah, this was weird: `store_true` works by having a default of False, but here the default is set to True, so the flag is always True regardless of what you pass.
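The interaction described above is easy to reproduce in isolation (a minimal standalone parser, not the actual export_llama argument parser): combining `action="store_true"` with an explicit `default=True` makes the flag a no-op.

```python
import argparse

# action="store_true" normally defaults to False, so passing the flag
# flips it to True. An explicit default=True overrides that baseline,
# so the value is True whether or not the flag is given.
parser = argparse.ArgumentParser()
parser.add_argument("-kv", "--kv_cache", action="store_true", default=True)

print(parser.parse_args([]).kv_cache)        # True (flag omitted)
print(parser.parse_args(["-kv"]).kv_cache)   # True (flag passed)
```

The idiomatic fix for an on-by-default boolean is `argparse.BooleanOptionalAction` (Python 3.9+), which generates a paired `--no-kv_cache` flag.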

@jackzhxng jackzhxng requested a review from tarun292 November 14, 2024 18:38
Base automatically changed from jz/tt-llama-2 to main November 14, 2024 22:01
@jackzhxng jackzhxng merged commit 6c944db into main Nov 14, 2024
38 of 39 checks passed
@jackzhxng jackzhxng deleted the jz/native-runner-tt branch November 14, 2024 22:04
Labels
- CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
- release notes: examples: Changes to any of our example LLM integrations, such as Llama3 and Llava.