Commit 01974e0: "Lunwen comments"
1 parent d4a335c

2 files changed: +22 −13 lines

examples/models/phi-3-mini/export_phi-3-mini.py (3 additions, 1 deletion)

@@ -16,6 +16,7 @@
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
 from executorch.exir import to_edge
 from torch._export import capture_pre_autograd_graph
+from torch.export.experimental import _export_forward_backward
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -80,9 +81,10 @@ def main(args) -> None:
         strict=False,
         pre_dispatch=False,
     )
+    joint_graph = _export_forward_backward(model)
 
     edge_config = get_xnnpack_edge_compile_config()
-    edge_manager = to_edge(model, compile_config=edge_config)
+    edge_manager = to_edge(joint_graph, compile_config=edge_config)
     edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
     et_program = edge_manager.to_executorch()
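For context on the import added above: torch.export.experimental._export_forward_backward is a private, experimental PyTorch API that takes an ExportedProgram whose forward() returns a loss and emits a joint graph containing both the forward pass and its gradient computation, so the lowered program can be trained. A minimal sketch of the pattern with a toy module (ToyLossModule and its shapes are illustrative, not part of this commit):

    import torch
    from torch.export import export
    from torch.export.experimental import _export_forward_backward

    class ToyLossModule(torch.nn.Module):
        # Illustrative stand-in: forward() must return the loss itself so
        # the joint graph can differentiate through it.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 4)
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, target):
            return self.loss(self.linear(x), target)

    ep = export(ToyLossModule(), (torch.randn(2, 8), torch.randint(0, 4, (2,))))
    # Joint program: forward outputs plus gradients for trainable parameters.
    joint_graph = _export_forward_backward(ep)

The joint program is then handed to to_edge(...), exactly as the diff does with joint_graph.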

examples/models/phi3-mini-lora/export_model.py (19 additions, 12 deletions)

@@ -13,6 +13,11 @@
 from torchtune.models.phi3._model_builders import lora_phi3_mini
 
 class TrainingModule(torch.nn.Module):
+    """
+    The model being trained should return the loss from forward(). This
+    class wraps the actual phi3-mini model and calculates an arbitrary
+    loss with its forward() output.
+    """
     def __init__(self, model, loss):
         super().__init__()
         self.model = model
@@ -21,14 +26,14 @@ def __init__(self, model, loss):
     def forward(self, input):
         # Output is of the shape (seq_len, vocab_size).
         output = self.model(input)
-        # Vocab size of 32064 is taken from the phi3 model itself.
+        # 32064 is the vocab size of the phi3-mini model.
         target = zeros((1, 32064), dtype=long)
         return self.loss(output, target)
 
 @no_grad()
-def export_mini_phi3_lora(model) -> None:
+def export_phi3_mini_lora(model) -> None:
     """
-    Export the example mini-phi3 with LoRA model to executorch.
+    Export the example phi3-mini with LoRA model to executorch.
 
     Note: need to use the SDPBackend's custom kernel for sdpa (scaled
     dot product attention) because the default sdpa kernel used in the
@@ -50,15 +55,15 @@ def export_mini_phi3_lora(model) -> None:
     executorch_program = edge_program.to_executorch()
 
     # 4. Save the compiled .pte program.
-    print("Saving to mini_phi3_lora.pte")
-    with open("mini_phi3_lora.pte", "wb") as file:
+    print("Saving to phi3_mini_lora.pte")
+    with open("phi3_mini_lora.pte", "wb") as file:
         file.write(executorch_program.buffer)
 
     print("Done.")
 
-def export_mini_phi3_lora_training(model) -> None:
+def export_phi3_mini_lora_training(model) -> None:
     """
-    Export the example mini-phi3 with LoRA model to executorch for training, only.
+    Export the example phi3-mini with LoRA model to executorch for training only.
     """
     print("Exporting mini phi3 with LoRA for training")
     # 1. torch.export: Defines the program with the ATen operator set.
@@ -73,19 +78,21 @@ def export_mini_phi3_lora_training(model) -> None:
     print("Lowering to edge dialect")
     edge_program = to_edge(joint_graph)
 
+    print(edge_program._edge_programs["forward"].graph_module)
+
     # 3. to_executorch: Convert the graph to an ExecuTorch program.
     print("Exporting to executorch")
     executorch_program = edge_program.to_executorch()
 
     # 4. Save the compiled .pte program.
-    print("Saving to mini_phi3_lora_training.pte")
-    with open("mini_phi3_lora_training.pte", "wb") as file:
+    print("Saving to phi3_mini_lora_training.pte")
+    with open("phi3_mini_lora_training.pte", "wb") as file:
         file.write(executorch_program.buffer)
 
     print("Done.")
 
 
-def run_mini_phi3_lora(model) -> Tensor:
+def run_phi3_mini_lora(model) -> Tensor:
     """Run the model and return the result."""
     # Input shape: (batch_size, seq_len).
     args = zeros((1, 10), dtype=int64)
@@ -103,11 +110,11 @@ def main() -> None:
     )
 
     # Export for inference.
-    export_mini_phi3_lora(lora_model)
+    export_phi3_mini_lora(lora_model)
 
     # Export for training.
     lora_training_model = TrainingModule(lora_model, torch.nn.CrossEntropyLoss())
-    export_mini_phi3_lora_training(lora_training_model)
+    export_phi3_mini_lora_training(lora_training_model)
 
 
 if __name__ == "__main__":
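Putting the two files together, the recipe is: wrap the model so forward() returns a loss, export, build the joint forward/backward graph, lower to edge dialect, and serialize. A hedged end-to-end sketch with a toy stand-in for lora_phi3_mini (WrappedForTraining, base_model, and toy_training.pte are illustrative names; this sketches the flow rather than a verified program):

    import torch
    from torch.export import export
    from torch.export.experimental import _export_forward_backward
    from executorch.exir import to_edge

    class WrappedForTraining(torch.nn.Module):
        # Same idea as TrainingModule above: fold the loss into forward().
        def __init__(self, model, loss):
            super().__init__()
            self.model = model
            self.loss = loss

        def forward(self, tokens, labels):
            return self.loss(self.model(tokens), labels)

    # Toy stand-in for the real phi3-mini LoRA model.
    base_model = torch.nn.Sequential(
        torch.nn.Embedding(100, 16),  # (1, 10) int64 tokens -> (1, 10, 16)
        torch.nn.Flatten(),           # -> (1, 160)
        torch.nn.Linear(160, 100),    # -> (1, 100) logits
    )
    wrapped = WrappedForTraining(base_model, torch.nn.CrossEntropyLoss())
    tokens = torch.zeros((1, 10), dtype=torch.int64)
    labels = torch.zeros((1,), dtype=torch.int64)

    joint_graph = _export_forward_backward(export(wrapped, (tokens, labels)))
    executorch_program = to_edge(joint_graph).to_executorch()
    with open("toy_training.pte", "wb") as file:
        file.write(executorch_program.buffer)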

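On the SDPBackend note in the docstring above: the usual workaround is to run export under PyTorch's sdpa_kernel context manager pinned to the math backend, since the default fused scaled-dot-product-attention kernel does not trace cleanly for export. A sketch, reusing the illustrative names from the previous block:

    from torch.export import export
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Force the decomposable math SDPA kernel while tracing for export.
    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph = export(wrapped, (tokens, labels))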