Skip to content

Commit 4bab1f2

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Lower phi3 mini with LoRA to edge for training (#4722)
Summary: Exports phi3-mini with LoRA for training. The process involves AOT Autograd tracing a backward graph, which is combined with the forward graph into a joint graph that is then lowered to ExecuTorch. Differential Revision: D61309917 Pulled By: dvorjackz
1 parent 7b795d7 commit 4bab1f2

File tree

2 files changed

+72
-11
lines changed

2 files changed

+72
-11
lines changed

examples/models/phi3-mini-lora/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## Summary
2-
In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with additional LoRA layers to ExecuTorch.
2+
In this example, we export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with attention and MLP LoRA layers to ExecuTorch, for both inference and training. Note: at the moment the exported training model can only be used for training.
33

44
## Instructions
55
### Step 1: [Optional] Install ExecuTorch dependencies
@@ -9,12 +9,12 @@ In this exmaple, we export a model ([phi-3-mini](https://github.com/pytorch/exec
99
- `./examples/models/phi3-mini-lora/install_requirements.sh`
1010

1111
### Step 3: Export and run the model
12-
1. Export the model to ExecuTorch.
12+
1. Export the inference and training models to ExecuTorch.
1313
```
1414
python export_model.py
1515
```
1616

17-
2. Run the model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
17+
2. Run the inference model using an example runtime. For more detailed steps on this, check out [Build & Run](https://pytorch.org/executorch/stable/getting-started-setup.html#build-run).
1818
```
1919
# Clean and configure the CMake build system. Compiled programs will appear in the executorch/cmake-out directory we create here.
2020
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)

examples/models/phi3-mini-lora/export_model.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,39 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import torch
78
from executorch.exir import to_edge
89
from torch import int64, long, no_grad, randint, Tensor, zeros
910
from torch.export import export, ExportedProgram
11+
from torch.export.experimental import _export_forward_backward
1012
from torch.nn.attention import sdpa_kernel, SDPBackend
1113
from torchtune.models.phi3._model_builders import lora_phi3_mini
1214

1315

16+
class TrainingModule(torch.nn.Module):
    """Wrapper that makes the wrapped model return a loss from forward().

    Exporting for training requires forward() to produce the training
    loss, so this module runs the underlying phi3-mini model and applies
    the supplied loss function to its output against a placeholder target.
    """

    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, input):
        # Run the wrapped model; its output has shape (seq_len, vocab_size).
        logits = self.model(input)
        # Placeholder target sized to phi3-mini's vocabulary (32064 entries).
        dummy_target = zeros((1, 32064), dtype=long)
        return self.loss(logits, dummy_target)
1436
@no_grad()
15-
def export_mini_phi3_lora(model) -> None:
37+
def export_phi3_mini_lora(model) -> None:
1638
"""
17-
Export the example mini-phi3 with LoRA model to executorch.
39+
Export the example phi3-mini with LoRA model to executorch.
1840
1941
Note: need to use the SDPBackend's custom kernel for sdpa (scalable
2042
dot product attention) because the default sdpa kernel used in the
@@ -36,28 +58,67 @@ def export_mini_phi3_lora(model) -> None:
3658
executorch_program = edge_program.to_executorch()
3759

3860
# 4. Save the compiled .pte program.
39-
print("Saving to mini_phi3_lora.pte")
40-
with open("mini_phi3_lora.pte", "wb") as file:
61+
print("Saving to phi3_mini_lora.pte")
62+
with open("phi3_mini_lora.pte", "wb") as file:
4163
file.write(executorch_program.buffer)
4264

4365
print("Done.")
4466

4567

46-
def run_mini_phi3_lora(model) -> Tensor:
68+
def export_phi3_mini_lora_training(model) -> None:
    """Export the example phi3-mini-with-LoRA model to ExecuTorch for training.

    Traces a joint forward/backward graph from the (loss-returning) model,
    lowers it through the edge dialect, and writes the resulting program
    to phi3_mini_lora_training.pte.
    """
    print("Exporting mini phi3 with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.
    print("Exporting to aten dialect")
    sample_inputs = (randint(0, 100, (1, 100), dtype=long),)
    with sdpa_kernel([SDPBackend.MATH]):
        aten_program: ExportedProgram = export(model, sample_inputs)
        print("Creating a joint forward-backwards graph for training")
        joint_program = _export_forward_backward(aten_program)

        # 2. to_edge: Make optimizations for Edge devices.
        print("Lowering to edge dialect")
        edge_prog = to_edge(joint_program)

        # NOTE(review): debug dump of a private attribute — presumably temporary.
        print(edge_prog._edge_programs["forward"].graph_module)

        # 3. to_executorch: Convert the graph to an ExecuTorch program.
        print("Exporting to executorch")
        et_program = edge_prog.to_executorch()

        # 4. Save the compiled .pte program.
        print("Saving to phi3_mini_lora_training.pte")
        with open("phi3_mini_lora_training.pte", "wb") as f:
            f.write(et_program.buffer)

        print("Done.")
97+
98+
99+
def run_phi3_mini_lora(model) -> Tensor:
    """Execute one forward pass of the model and return its output."""
    # Input shape: (batch_size, seq_len).
    tokens = zeros((1, 10), dtype=int64)
    # Disable training-only behavior (e.g. dropout) before running.
    model.eval()
    return model(tokens)
52106

53107

54108
def main() -> None:
    """Build the LoRA-augmented phi3-mini and export it for inference and training."""
    print("Main")
    base_model = lora_phi3_mini(
        lora_attn_modules=[
            "q_proj",
        ]
    )

    # Export for inference.
    export_phi3_mini_lora(base_model)

    # Export for training.
    trainable = TrainingModule(base_model, torch.nn.CrossEntropyLoss())
    export_phi3_mini_lora_training(trainable)
61122

62123

63124
if __name__ == "__main__":

0 commit comments

Comments
 (0)