Commit d3da92d

Lower phi3 mini with LoRA to edge for training

Differential Revision: D61309917
Pull Request resolved: #4722

1 parent: 2b3c01c

File tree: 2 files changed, +73 −11 lines

- examples/models/phi3-mini-lora/README.md
- examples/models/phi3-mini-lora/export_model.py

examples/models/phi3-mini-lora/README.md (3 additions, 3 deletions)

````diff
@@ -1,5 +1,5 @@
 ## Summary
-In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with additional LoRA layers to ExecuTorch.
+In this example, we export to ExecuTorch a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with attention and MLP LoRA layers. The model is exported for both inference and training. Note: for now, the exported training model can only train.
 
 ## Instructions
 ### Step 1: [Optional] Install ExecuTorch dependencies
@@ -9,12 +9,12 @@ In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/exec
 - `./examples/models/phi3-mini-lora/install_requirements.sh`
 
 ### Step 3: Export and run the model
-1. Export the model to ExecuTorch.
+1. Export the inference and training models to ExecuTorch.
 ```
 python export_model.py
 ```
 
-2. Run the model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
+2. Run the inference model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
 ```
 # Clean and configure the CMake build system. Compiled programs will appear in the executorch/cmake-out directory we create here.
 (rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
````
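
Besides the CMake runner flow the README points at, the exported inference program can be smoke-tested directly from Python. A minimal sketch, assuming the ExecuTorch pybindings (`executorch.extension.pybindings.portable_lib`) are built and importable in your environment; this check is an assumption of the sketch, not part of this commit:

```python
# Hypothetical smoke test; assumes the ExecuTorch Python bindings are built.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the inference program written by export_model.py.
program = _load_for_executorch("phi3_mini_lora.pte")

# Match the static shape the model was exported with: (batch_size=1, seq_len=100).
tokens = torch.randint(0, 100, (1, 100), dtype=torch.long)

# forward() takes a list of inputs and returns a list of outputs.
outputs = program.forward([tokens])
print(outputs[0].shape)  # per the export script, logits of shape (seq_len, vocab_size)
```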

examples/models/phi3-mini-lora/export_model.py (70 additions, 8 deletions)

```diff
@@ -4,17 +4,40 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 from executorch.exir import to_edge
 from torch import int64, long, no_grad, randint, Tensor, zeros
 from torch.export import export, ExportedProgram
+from torch.export.experimental import _export_forward_backward
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torchtune.models.phi3._model_builders import lora_phi3_mini
 
+vocab_size = 32064
+
+
+class TrainingModule(torch.nn.Module):
+    """
+    The model being trained should return the loss from forward(). This
+    class wraps the actual phi3-mini model and calculates an arbitrary
+    loss with its forward() output.
+    """
+
+    def __init__(self, model, loss):
+        super().__init__()
+        self.model = model
+        self.loss = loss
+
+    def forward(self, input):
+        # Output is of the shape (seq_len, vocab_size).
+        output = self.model(input)
+        target = zeros((1, vocab_size), dtype=long)
+        return self.loss(output, target)
+
 
 @no_grad()
-def export_mini_phi3_lora(model) -> None:
+def export_phi3_mini_lora(model) -> None:
     """
-    Export the example mini-phi3 with LoRA model to executorch.
+    Export the example phi3-mini with LoRA model to executorch.
 
     Note: need to use the SDPBackend's custom kernel for sdpa (scalable
     dot product attention) because the default sdpa kernel used in the
@@ -36,28 +59,67 @@ def export_mini_phi3_lora(model) -> None:
     executorch_program = edge_program.to_executorch()
 
     # 4. Save the compiled .pte program.
-    print("Saving to mini_phi3_lora.pte")
-    with open("mini_phi3_lora.pte", "wb") as file:
+    print("Saving to phi3_mini_lora.pte")
+    with open("phi3_mini_lora.pte", "wb") as file:
+        file.write(executorch_program.buffer)
+
+    print("Done.")
+
+
+def export_phi3_mini_lora_training(model) -> None:
+    """
+    Export the example phi3-mini with LoRA model to executorch for training only.
+    """
+    print("Exporting phi3-mini with LoRA for training")
+    # 1. torch.export: Defines the program with the ATen operator set.
+    print("Exporting to aten dialect")
+    example_args = (randint(0, 100, (1, 100), dtype=long),)
+    with sdpa_kernel([SDPBackend.MATH]):
+        exported_graph: ExportedProgram = export(model, example_args)
+    print("Creating a joint forward-backwards graph for training")
+    joint_graph = _export_forward_backward(exported_graph)
+
+    # 2. to_edge: Make optimizations for Edge devices.
+    print("Lowering to edge dialect")
+    edge_program = to_edge(joint_graph)
+
+    print(edge_program._edge_programs["forward"].graph_module)
+
+    # 3. to_executorch: Convert the graph to an ExecuTorch program.
+    print("Exporting to executorch")
+    executorch_program = edge_program.to_executorch()
+
+    # 4. Save the compiled .pte program.
+    print("Saving to phi3_mini_lora_training.pte")
+    with open("phi3_mini_lora_training.pte", "wb") as file:
         file.write(executorch_program.buffer)
 
     print("Done.")
 
 
-def run_mini_phi3_lora(model) -> Tensor:
+def run_phi3_mini_lora(model) -> Tensor:
     """Run the model and return the result."""
-    args = zeros([3072, 1], dtype=int64)
+    # Input shape: (batch_size, seq_len).
+    args = zeros((1, 10), dtype=int64)
     model.eval()
     res = model(args)
     return res
 
 
 def main() -> None:
-    mini_lora_model = lora_phi3_mini(
+    print("Main")
+    lora_model = lora_phi3_mini(
         lora_attn_modules=[
             "q_proj",
         ]
     )
-    export_mini_phi3_lora(mini_lora_model)
+
+    # Export for inference.
+    export_phi3_mini_lora(lora_model)
+
+    # Export for training.
+    lora_training_model = TrainingModule(lora_model, torch.nn.CrossEntropyLoss())
+    export_phi3_mini_lora_training(lora_training_model)
 
 
 if __name__ == "__main__":
```
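
The core new step in this commit is `_export_forward_backward`, which takes a `torch.export` program whose forward() returns a scalar loss (the contract TrainingModule enforces) and produces a joint forward-backward graph that can then be lowered like any other exported program. A minimal, self-contained sketch of the same pattern on a toy model; the toy module and its sizes are invented for illustration, and only the API sequence mirrors the commit:

```python
import torch
from torch.export import export
from torch.export.experimental import _export_forward_backward


class ToyTrainingModule(torch.nn.Module):
    """Same pattern as TrainingModule above: forward() returns the loss."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        logits = self.linear(x)  # shape: (batch, num_classes)
        target = torch.zeros((1,), dtype=torch.long)  # arbitrary fixed target
        return self.loss(logits, target)


example_args = (torch.randn(1, 8),)
exported = export(ToyTrainingModule(), example_args)

# Joint graph: the forward computation plus the gradients of the loss
# with respect to the module's parameters.
joint = _export_forward_backward(exported)
print(joint.graph_module)
```

From here, `to_edge(joint)` followed by `.to_executorch()` would lower the joint graph exactly as `export_phi3_mini_lora_training` does above.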
