@@ -13,6 +13,11 @@
 from torchtune.models.phi3._model_builders import lora_phi3_mini
 
 
 class TrainingModule(torch.nn.Module):
+    """
+    The model being trained should return the loss from forward(). This
+    class wraps the actual phi3-mini model and calculates an arbitrary
+    loss with its forward() output.
+    """
     def __init__(self, model, loss):
         super().__init__()
         self.model = model
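As context for this hunk: the new docstring describes a common pattern for training export, where the module's forward() ends in a scalar loss. Here is a minimal, self-contained sketch of that pattern with a toy linear model (ToyTrainingModule and its shapes are illustrative stand-ins, not part of this change):

```python
import torch


class ToyTrainingModule(torch.nn.Module):
    """Toy stand-in for TrainingModule: forward() returns the loss."""

    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, input):
        # (batch, vocab) logits from the wrapped model.
        logits = self.model(input)
        # Placeholder all-zeros target, mirroring the example's dummy target.
        target = torch.zeros(logits.shape[0], dtype=torch.long)
        return self.loss(logits, target)


toy = ToyTrainingModule(torch.nn.Linear(8, 32), torch.nn.CrossEntropyLoss())
print(toy(torch.randn(1, 8)))  # a scalar loss tensor
```

Returning the loss from forward() matters because the training export below differentiates the graph's output; a module that returned raw logits would leave nothing to backpropagate from.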
@@ -21,14 +26,14 @@ def __init__(self, model, loss):
     def forward(self, input):
         # Output is of the shape (seq_len, vocab_size).
         output = self.model(input)
-        # Vocab size of 32064 is taken from the phi3 model itself.
+        # 32064 is the vocab size of the phi3-mini model.
         target = zeros((1, 32064), dtype=long)
         return self.loss(output, target)
 
 @no_grad()
-def export_mini_phi3_lora(model) -> None:
+def export_phi3_mini_lora(model) -> None:
     """
-    Export the example mini-phi3 with LoRA model to executorch.
+    Export the example phi3-mini with LoRA model to executorch.
 
     Note: need to use the SDPBackend's custom kernel for sdpa (scalable
     dot product attention) because the default sdpa kernel used in the
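The docstring's note is truncated by the hunk boundary, but it concerns pinning scaled dot product attention to a single backend so the traced graph stays exportable. A minimal sketch of that pattern, assuming PyTorch 2.3+'s torch.nn.attention.sdpa_kernel context manager and the math backend (the exact backend choice in this example is not visible in the hunk):

```python
import torch
from torch.export import export
from torch.nn.attention import SDPBackend, sdpa_kernel


def export_with_math_sdpa(model: torch.nn.Module, example_input: torch.Tensor):
    # Pin sdpa to the decomposed math kernel while tracing, so torch.export
    # does not capture a runtime backend-selection path.
    with sdpa_kernel([SDPBackend.MATH]):
        return export(model, (example_input,))
```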
@@ -50,15 +55,15 @@ def export_mini_phi3_lora(model) -> None:
     executorch_program = edge_program.to_executorch()
 
     # 4. Save the compiled .pte program.
-    print("Saving to mini_phi3_lora.pte")
-    with open("mini_phi3_lora.pte", "wb") as file:
+    print("Saving to phi3_mini_lora.pte")
+    with open("phi3_mini_lora.pte", "wb") as file:
         file.write(executorch_program.buffer)
 
     print("Done.")
 
-def export_mini_phi3_lora_training(model) -> None:
+def export_phi3_mini_lora_training(model) -> None:
     """
-    Export the example mini-phi3 with LoRA model to executorch for training, only.
+    Export the example phi3-mini with LoRA model to executorch for training only.
     """
     print("Exporting mini phi3 with LoRA for training")
     # 1. torch.export: Defines the program with the ATen operator set.
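The step-1 code itself falls between these hunks, so it is not shown. For orientation, a hedged sketch of how a joint forward-plus-backward graph is typically produced for training export, assuming the private torch.export.experimental._export_forward_backward helper (an experimental API that may change; the actual step 1 in this file may differ):

```python
from torch import int64, zeros
from torch.export import export
from torch.export.experimental import _export_forward_backward


def export_joint_graph(model):
    # `model` is expected to return a scalar loss, like TrainingModule above.
    example_args = (zeros((1, 10), dtype=int64),)
    exported = export(model, example_args)
    # Append the backward pass, yielding the joint graph lowered below.
    return _export_forward_backward(exported)
```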
@@ -73,19 +78,21 @@ def export_mini_phi3_lora_training(model) -> None:
     print("Lowering to edge dialect")
     edge_program = to_edge(joint_graph)
 
+    print(edge_program._edge_programs["forward"].graph_module)
+
     # 3. to_executorch: Convert the graph to an ExecuTorch program.
     print("Exporting to executorch")
     executorch_program = edge_program.to_executorch()
 
     # 4. Save the compiled .pte program.
-    print("Saving to mini_phi3_lora_training.pte")
-    with open("mini_phi3_lora_training.pte", "wb") as file:
+    print("Saving to phi3_mini_lora_training.pte")
+    with open("phi3_mini_lora_training.pte", "wb") as file:
         file.write(executorch_program.buffer)
 
     print("Done.")
 
 
-def run_mini_phi3_lora(model) -> Tensor:
+def run_phi3_mini_lora(model) -> Tensor:
     """Run the model and return the result."""
     # Input shape: (batch_size, seq_len).
     args = zeros((1, 10), dtype=int64)
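Once a .pte file has been saved by the exporters above, it can be executed outside this script. A hedged sketch of loading and running the inference program, assuming ExecuTorch's Python bindings are built and exposed at this module path (the path has moved between ExecuTorch releases):

```python
from torch import int64, zeros

from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("phi3_mini_lora.pte")
# Same (batch_size, seq_len) input shape the example traces with.
outputs = program.forward((zeros((1, 10), dtype=int64),))
print(outputs[0].shape)
```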
@@ -103,11 +110,11 @@ def main() -> None:
     )
 
     # Export for inference.
-    export_mini_phi3_lora(lora_model)
+    export_phi3_mini_lora(lora_model)
 
     # Export for training.
     lora_training_model = TrainingModule(lora_model, torch.nn.CrossEntropyLoss())
-    export_mini_phi3_lora_training(lora_training_model)
+    export_phi3_mini_lora_training(lora_training_model)
 
 
 if __name__ == "__main__":