
Commit 41ee92d

jackzhxng authored and pytorchbot committed
Update phi-3-mini lora export code and readme (#5327)
Summary: Updated readme and example export code ahead of branch cut.
Pull Request resolved: #5327
Test Plan:
- Exported manually
Reviewed By: JacobSzwejbka
Differential Revision: D62623250
Pulled By: dvorjackz
fbshipit-source-id: 79ee3ad1d42ae961d94d225ee1e642c5bc540127
(cherry picked from commit 08f16d0)
1 parent eecf74f, commit 41ee92d

2 files changed (+20, -7 lines)

2 files changed

+20
-7
lines changed

examples/models/phi-3-mini-lora/README.md

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 ## Summary
-In this example, we export to ExecuTorch a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with attention and mlp LoRA layers. The model is exported to ExecuTorch for both inference and training. Note: the exported training model can only train at the moment.
+In this example, we showcase how to export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with LoRA layers to ExecuTorch. The model is exported to ExecuTorch for both inference and training.
+
+To see how you can use the model exported for training in a fully involved finetuning loop, please see our example on [LLM PTE Finetuning](https://github.com/pytorch/executorch/tree/main/examples/llm_pte_finetuning).
 
 ## Instructions
 ### Step 1: [Optional] Install ExecuTorch dependencies
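
For readers skimming this commit without the full example, the inference-side flow the updated README refers to follows the usual torch.export to to_edge to to_executorch path. A minimal sketch under stated assumptions: `ToyLM`, the shapes, and the output filename are illustrative stand-ins, not code from this repo.

```python
# Minimal sketch of exporting a small model to ExecuTorch for inference.
# ToyLM stands in for the LoRA-augmented phi-3-mini; shapes are illustrative.
import torch
from torch.export import export, ExportedProgram
from executorch.exir import to_edge


class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size: int = 100, dim: int = 16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Returns logits of shape (batch, seq_len, vocab_size).
        return self.head(self.emb(tokens))


model = ToyLM().eval()
example_args = (torch.randint(0, 100, (1, 10), dtype=torch.long),)

# 1. Capture an ATen-dialect program, 2. lower to Edge dialect, 3. serialize to a .pte file.
aten_program: ExportedProgram = export(model, example_args)
executorch_program = to_edge(aten_program).to_executorch()
with open("toy_lm.pte", "wb") as f:
    f.write(executorch_program.buffer)
```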

examples/models/phi-3-mini-lora/export_model.py

Lines changed: 17 additions & 6 deletions
@@ -28,11 +28,13 @@ def __init__(self, model, loss):
         self.model = model
         self.loss = loss
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Output is of the shape (seq_len, vocab_size).
-        output = self.model(input)
-        target = zeros((1, vocab_size), dtype=long)
-        return self.loss(output, target)
+        logits = self.model(input)
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+        logits = logits.transpose(1, 2)
+        return self.loss(logits, labels)
 
 
 @no_grad()
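
For context on the change above: the new forward() computes a standard causal (next-token) loss, where the logit at position t is scored against the token at position t + 1, and the logits are transposed because nn.CrossEntropyLoss expects the class dimension second. A self-contained sketch with toy shapes (nothing here is tied to phi-3-mini):

```python
# Illustration of the shift-and-transpose next-token loss used in the updated forward().
import torch
from torch import nn

batch_size, seq_len, vocab_size = 1, 10, 100
logits = torch.randn(batch_size, seq_len, vocab_size)          # stand-in model output
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # stand-in targets

shifted_logits = logits[..., :-1, :].contiguous()   # positions 0..seq_len-2 predict ...
shifted_labels = labels[..., 1:].contiguous()        # ... tokens 1..seq_len-1
loss = nn.CrossEntropyLoss()(shifted_logits.transpose(1, 2), shifted_labels)
print(loss)  # scalar next-token prediction loss
```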
@@ -47,7 +49,11 @@ def export_phi3_mini_lora(model) -> None:
     model.eval()
     # 1. torch.export: Defines the program with the ATen operator set.
     print("Exporting to aten dialect")
-    example_args = (randint(0, 100, (1, 100), dtype=long),)
+    batch_size = 1
+    vocab_size = 100
+    seq_len = 10
+    tokens = randint(0, vocab_size, (batch_size, seq_len), dtype=long)
+    example_args = (tokens,)
     with sdpa_kernel([SDPBackend.MATH]):
         aten_dialect: ExportedProgram = export(model, example_args)

@@ -80,7 +86,12 @@ def export_phi3_mini_lora_training(model) -> None:
     print("Exporting phi3-mini with LoRA for training")
     # 1. torch.export: Defines the program with the ATen operator set.
     print("Exporting to aten dialect")
-    example_args = (randint(0, 100, (1, 100), dtype=long),)
+    batch_size = 1
+    vocab_size = 100
+    seq_len = 10
+    tokens = randint(0, vocab_size, (batch_size, seq_len), dtype=long)
+    labels = tokens
+    example_args = (tokens, labels)
     with sdpa_kernel([SDPBackend.MATH]):
         exported_graph: ExportedProgram = export(model, example_args)
     print("Creating a joint forward-backwards graph for training")
