Skip to content

Lower phi3 mini with LoRA to edge for training #4722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/models/phi3-mini-lora/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
## Summary
In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with additional LoRA layers to ExecuTorch.
In this example, we export to ExecuTorch a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with attention and MLP LoRA layers. The model is exported to ExecuTorch for both inference and training. Note: the exported training model can only train at the moment.

## Instructions
### Step 1: [Optional] Install ExecuTorch dependencies
Expand All @@ -9,12 +9,12 @@ In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/exec
- `./examples/models/phi3-mini-lora/install_requirements.sh`

### Step 3: Export and run the model
1. Export the model to ExecuTorch.
1. Export the inference and training models to ExecuTorch.
```
python export_model.py
```

2. Run the model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
2. Run the inference model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
```
# Clean and configure the CMake build system. Compiled programs will appear in the executorch/cmake-out directory we create here.
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
Expand Down
78 changes: 70 additions & 8 deletions examples/models/phi3-mini-lora/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,40 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir import to_edge
from torch import int64, long, no_grad, randint, Tensor, zeros
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune.models.phi3._model_builders import lora_phi3_mini

vocab_size = 32064


class TrainingModule(torch.nn.Module):
"""
The model being trained should return the loss from forward(). This
class wraps the actual phi3-mini model and calculates an arbitrary
loss with its forward() output.
"""

def __init__(self, model, loss):
super().__init__()
self.model = model
self.loss = loss

def forward(self, input):
# Output is of the shape (seq_len, vocab_size).
output = self.model(input)
target = zeros((1, vocab_size), dtype=long)
return self.loss(output, target)


@no_grad()
def export_mini_phi3_lora(model) -> None:
def export_phi3_mini_lora(model) -> None:
"""
Export the example mini-phi3 with LoRA model to executorch.
Export the example phi3-mini with LoRA model to executorch.

Note: need to use the SDPBackend's custom kernel for sdpa (scalable
dot product attention) because the default sdpa kernel used in the
Expand All @@ -36,28 +59,67 @@ def export_mini_phi3_lora(model) -> None:
executorch_program = edge_program.to_executorch()

# 4. Save the compiled .pte program.
print("Saving to mini_phi3_lora.pte")
with open("mini_phi3_lora.pte", "wb") as file:
print("Saving to phi3_mini_lora.pte")
with open("phi3_mini_lora.pte", "wb") as file:
file.write(executorch_program.buffer)

print("Done.")


def export_phi3_mini_lora_training(model) -> None:
"""
Export the example phi3-mini with LoRA model to executorch for training, only.
"""
print("Exporting phi3-mini with LoRA for training")
# 1. torch.export: Defines the program with the ATen operator set.
print("Exporting to aten dialect")
example_args = (randint(0, 100, (1, 100), dtype=long),)
with sdpa_kernel([SDPBackend.MATH]):
exported_graph: ExportedProgram = export(model, example_args)
print("Creating a joint forward-backwards graph for training")
joint_graph = _export_forward_backward(exported_graph)

# 2. to_edge: Make optimizations for Edge devices.
print("Lowering to edge dialect")
edge_program = to_edge(joint_graph)

print(edge_program._edge_programs["forward"].graph_module)

# 3. to_executorch: Convert the graph to an ExecuTorch program.
print("Exporting to executorch")
executorch_program = edge_program.to_executorch()

# 4. Save the compiled .pte program.
print("Saving to phi3_mini_lora_training.pte")
with open("phi3_mini_lora_training.pte", "wb") as file:
file.write(executorch_program.buffer)

print("Done.")


def run_mini_phi3_lora(model) -> Tensor:
def run_phi3_mini_lora(model) -> Tensor:
"""Run the model and return the result."""
args = zeros([3072, 1], dtype=int64)
# Input shape: (batch_size, seq_len).
args = zeros((1, 10), dtype=int64)
model.eval()
res = model(args)
return res


def main() -> None:
mini_lora_model = lora_phi3_mini(
print("Main")
lora_model = lora_phi3_mini(
lora_attn_modules=[
"q_proj",
]
)
export_mini_phi3_lora(mini_lora_model)

# Export for inference.
export_phi3_mini_lora(lora_model)

# Export for training.
lora_training_model = TrainingModule(lora_model, torch.nn.CrossEntropyLoss())
export_phi3_mini_lora_training(lora_training_model)


if __name__ == "__main__":
Expand Down
Loading