Commit 3f77078

serialize fqns.

Differential Revision: D61864979
Pull Request resolved: #4931
1 parent 1cea0ee commit 3f77078

File tree

2 files changed: +18 −5 lines

2 files changed

+18
-5
lines changed

exir/emit/_emit_program.py

Lines changed: 9 additions & 3 deletions
@@ -84,20 +84,26 @@ def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.Grap
 def _get_training_metadata(methods: Dict[str, ExportedProgram]) -> Dict[str, int]:
     gradients_method_prefix = "__et_training_gradients_index_"
     parameters_method_prefix = "__et_training_parameters_index_"
+    fqn_method_prefix = "__et_training_fqn_"
     training_metadata = {}
     for name, method in methods.items():
         found_grad = False
         found_param = False
+        fqns = []
         i = 0
         for output_spec in method.graph_signature.output_specs:
-            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER and not found_grad:
-                training_metadata[gradients_method_prefix + name] = i
-                found_grad = True
+            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
+                if not found_grad:
+                    training_metadata[gradients_method_prefix + name] = i
+                    found_grad = True
+                fqns.append(output_spec.target)
             elif output_spec.kind == OutputKind.TOKEN and not found_param:
                 assert found_grad  # Params must come after gradients
                 training_metadata[parameters_method_prefix + name] = i
                 found_param = True
             i += 1
+        if len(fqns) > 0:
+            training_metadata[fqn_method_prefix + name] = fqns
     return training_metadata

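For context, a minimal sketch of the metadata dict this change produces, assuming a method named "forward" whose graph emits gradients for linear.weight and linear.bias (the method name and FQNs are illustrative, borrowed from the test below; the indices match the test's expectations):

# Illustrative only: the shape of the dict _get_training_metadata returns
# for a hypothetical method "forward". The FQN entry is the new addition;
# note its value is a list of strings rather than an int, so consumers
# should not assume every entry is an index.
training_metadata = {
    "__et_training_gradients_index_forward": 1,   # first gradient output
    "__et_training_parameters_index_forward": 3,  # first parameter output
    "__et_training_fqn_forward": ["linear.weight", "linear.bias"],  # new
}
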
exir/tests/test_joint_graph.py

Lines changed: 9 additions & 2 deletions
@@ -110,7 +110,7 @@ def forward(self, x, y):
         self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))
 
         self.assertEqual(
-            len(et.executorch_program.execution_plan), 3
+            len(et.executorch_program.execution_plan), 4
         )  # forward + 2 training metadata functions
 
         # gradient outputs start at index 1
@@ -121,10 +121,17 @@ def forward(self, x, y):
             1,
         )
 
-        # parameter outputs start at index 3
         self.assertEqual(
             et.executorch_program.execution_plan[2]  # pyre-ignore
             .values[0]
+            .val.string_val,
+            "linear.weight",
+        )
+
+        # parameter outputs start at index 3
+        self.assertEqual(
+            et.executorch_program.execution_plan[3]  # pyre-ignore
+            .values[0]
             .val.int_val,
             3,
         )
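
As a usage note, a consumer could recover the serialized FQNs from the emitted program along these lines. This is a hedged sketch, not part of the diff: the helper name is hypothetical, and it assumes each metadata execution plan is named with the "__et_training_fqn_" prefix plus the method name and stores one FQN string per value slot (the accessors mirror the ones the test exercises):

# Hypothetical helper, not part of this commit: collect the serialized
# FQNs for a method, using the same accessors the test exercises
# (execution_plan, .values, .val.string_val).
def get_training_fqns(executorch_program, method_name):
    plan_name = "__et_training_fqn_" + method_name
    for plan in executorch_program.execution_plan:
        if plan.name == plan_name:  # assumes plans carry their method name
            return [v.val.string_val for v in plan.values]
    return []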
