Skip to content

Export mini phi3 LoRA model #4062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/models/phi3-mini-lora/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
## Summary
In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with additional LoRA layers to ExecuTorch.

## Instructions
### Step 1: [Optional] Install ExecuTorch dependencies
`./install_requirements.sh` in ExecuTorch root directory.

### Step 2: Install TorchTune nightly
The LoRA model used is recent and is not yet officially released on `TorchTune`. To be able to run this example, you will need to run the following to install TorchTune nighly:
- `./examples/models/llava_encoder/install_requirements.sh`'

### Step 3: Export and run the model
1. Export the model to ExecuTorch.
```
python export_model.py
```

2. Run the model using an example runtime. For more more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
```
# Clean and configure the CMake build system. Compiled programs will appear in the executorch/cmake-out directory we create here.
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)

# Build the executor_runner target
cmake --build cmake-out --target executor_runner -j9

./cmake-out/executor_runner --model_path mini_phi3_lora.pte
```
64 changes: 64 additions & 0 deletions examples/models/phi3-mini-lora/export_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir import to_edge
from torch import int64, long, no_grad, randint, Tensor, zeros
from torch.export import export, ExportedProgram
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune.models.phi3._model_builders import lora_phi3_mini


@no_grad()
def export_mini_phi3_lora(model) -> None:
"""
Export the example mini-phi3 with LoRA model to executorch.
Note: need to use the SDPBackend's custom kernel for sdpa (scalable
dot product attention) because the default sdpa kernel used in the
model results in a internally mutating graph.
"""
model.eval()
# 1. torch.export: Defines the program with the ATen operator set.
print("Exporting to aten dialect")
example_args = (randint(0, 100, (1, 100), dtype=long),)
with sdpa_kernel([SDPBackend.MATH]):
aten_dialect: ExportedProgram = export(model, example_args)

# 2. to_edge: Make optimizations for Edge devices.
print("Lowering to edge dialect")
edge_program = to_edge(aten_dialect)

# 3. to_executorch: Convert the graph to an ExecuTorch program.
print("Exporting to executorch")
executorch_program = edge_program.to_executorch()

# 4. Save the compiled .pte program.
print("Saving to mini_phi3_lora.pte")
with open("mini_phi3_lora.pte", "wb") as file:
file.write(executorch_program.buffer)

print("Done.")


def run_mini_phi3_lora(model) -> Tensor:
"""Run the model and return the result."""
args = zeros([3072, 1], dtype=int64)
model.eval()
res = model(args)
return res


def main() -> None:
mini_lora_model = lora_phi3_mini(
lora_attn_modules=[
"q_proj",
]
)
export_mini_phi3_lora(mini_lora_model)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions examples/models/phi3-mini-lora/install_requirements.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Install nightly build of TorchTune.
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
pip install tiktoken
Loading