
Commit 643c381

Fix TrainingModule Parameter Bug
Differential Revision: D69568035
Pull Request resolved: #8443
Parent: 9ba5494
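
In short, as the diffs below show: forward_backward previously let run_method clone its outputs, so the SGD optimizer stepped on copies of the weights instead of the parameter tensors owned by the C++ module, and named_parameters() guarded on self.named_grads instead of self.named_params. The fix runs the training method with clone_outputs=False, clones only the user outputs and the gradients, and corrects the guard. The snippet below is a minimal, self-contained illustration in plain PyTorch (not the ExecuTorch API) of why an in-place update on a clone never reaches the original tensor:

import torch

weight = torch.ones(3)

# "Optimizer step" applied to a clone: the real weight is untouched.
cloned = weight.clone()
cloned.sub_(0.1)
print(torch.equal(weight, torch.ones(3)))  # True: weight unchanged

# Same update applied to the tensor itself (no clone): the real weight moves.
weight.sub_(0.1)
print(torch.equal(weight, torch.ones(3)))  # False: weight updated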

2 files changed: +27 / -16 lines

extension/training/pybindings/_training_module.py

Lines changed: 8 additions & 8 deletions
@@ -41,20 +41,20 @@ def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]
             self.parameters_method_prefix + method_name, ()
         )[0]

-        full_outputs = self.model.run_method(method_name, inputs)
+        # Important that the outputs are not cloned because we need the optimizer to
+        # be able to mutate the actual weights and not clones of them.
+        full_outputs = self.model.run_method(method_name, inputs, clone_outputs=False)

         user_outs = full_outputs[:grad_start_idx]
+        user_outs = [x.clone() for x in user_outs]
         grads = full_outputs[grad_start_idx:params_start_idx]
-        params = full_outputs[params_start_idx:]
+        grads = [grad.clone() for grad in grads]

-        # Important that the outputs are not cloned because we need the optimizer to
-        # be able to mutate the actual weights and not clones of them.
-        fqn = self.model.run_method(
-            self.fqn_method_prefix + method_name, (), clone_outputs=False
-        )
+        fqn = self.model.run_method(self.fqn_method_prefix + method_name, ())

         self.named_grads = dict(zip(fqn, grads))
         if self.named_params is None:
+            params = full_outputs[params_start_idx:]
             self.named_params = dict(zip(fqn, params))

         return user_outs
@@ -65,7 +65,7 @@ def named_gradients(self) -> Dict[str, Tensor]:
         return self.named_grads

     def named_parameters(self) -> Dict[str, Tensor]:
-        if self.named_grads is None:
+        if self.named_params is None:
             raise RuntimeError(
                 "Must call forward_backward before named_params. This will be fixed in a later version"
             )
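
For readability, here is how forward_backward and the start of named_parameters read after this patch, stitched together from the context and added lines above. The lines computing grad_start_idx sit above the hunk and are reconstructed as an assumption (mirroring how params_start_idx is obtained), so treat this as a sketch of the patched file, not a verbatim copy:

    def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
        # Assumption (above the hunk, not part of this diff): the split points of
        # the flat output list come from metadata methods emitted alongside the
        # training graph.
        grad_start_idx = self.model.run_method(
            self.gradients_method_prefix + method_name, ()
        )[0]
        params_start_idx = self.model.run_method(
            self.parameters_method_prefix + method_name, ()
        )[0]

        # Important that the outputs are not cloned because we need the optimizer to
        # be able to mutate the actual weights and not clones of them.
        full_outputs = self.model.run_method(method_name, inputs, clone_outputs=False)

        # Clone what the user sees and what the optimizer reads; the parameter
        # tensors themselves stay aliased to the real weights.
        user_outs = full_outputs[:grad_start_idx]
        user_outs = [x.clone() for x in user_outs]
        grads = full_outputs[grad_start_idx:params_start_idx]
        grads = [grad.clone() for grad in grads]

        fqn = self.model.run_method(self.fqn_method_prefix + method_name, ())

        self.named_grads = dict(zip(fqn, grads))
        if self.named_params is None:
            params = full_outputs[params_start_idx:]
            self.named_params = dict(zip(fqn, params))

        return user_outs

    def named_parameters(self) -> Dict[str, Tensor]:
        if self.named_params is None:
            raise RuntimeError(
                "Must call forward_backward before named_params. This will be fixed in a later version"
            )
        # ... (remainder unchanged and outside the hunks above)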

extension/training/pybindings/test/test.py

Lines changed: 19 additions & 8 deletions
@@ -28,20 +28,19 @@ def __init__(self):
         def forward(self, x, y):
             return self.loss(self.linear(x).softmax(dim=0), y)

-        def get_random_inputs(self):
-            return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))
+        def get_inputs(self):
+            return (torch.ones(3, dtype=torch.float32), torch.tensor([1.0, 0.0, 0.0]))

     def test(self):
         m = self.ModuleSimpleTrain()
-        ep = torch.export.export(m, m.get_random_inputs(), strict=True)
+        ep = torch.export.export(m, m.get_inputs(), strict=True)
         ep = _export_forward_backward(ep)
         ep = to_edge(ep)
         ep = ep.to_executorch()
         buffer = ep.buffer
         tm = _load_for_executorch_for_training_from_buffer(buffer)

-        tm.forward_backward("forward", m.get_random_inputs())
-        orig_param = list(tm.named_parameters().values())[0].clone()
+        orig_loss = tm.forward_backward("forward", m.get_inputs())
         optimizer = get_sgd_optimizer(
             tm.named_parameters(),
             0.1,
@@ -50,7 +49,19 @@ def test(self):
             0,
             False,
         )
+
+        cloned_params = list(tm.named_parameters().values())
+        cloned_params = [p.clone() for p in cloned_params]
+
         optimizer.step(tm.named_gradients())
-        self.assertFalse(
-            torch.allclose(orig_param, list(tm.named_parameters().values())[0])
-        )
+
+        # The python module caches the param tensors after the first
+        # inference. So this doesn't test if the params are actually
+        # updated in cpp world.
+        for p, cloned_p in zip(tm.named_parameters().values(), cloned_params):
+            self.assertFalse(torch.allclose(p, cloned_p))
+
+        # Test that the params actually changed in cpp by running against
+        # the same inputs again and seeing that the loss is different.
+        second_loss = tm.forward_backward("forward", m.get_inputs())
+        self.assertFalse(torch.allclose(orig_loss[0], second_loss[0]))
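
As a usage note, the flow this test exercises extends naturally to a small training loop. The sketch below reuses only calls that appear in the test above; the imports, the middle get_sgd_optimizer arguments (hidden between the two hunks), and having ModuleSimpleTrain available at module scope are assumptions, so treat it as an illustration rather than additional test code:

# Sketch only: assumes the same imports as test.py (outside the hunks shown)
# and a module-scope ModuleSimpleTrain exposing the get_inputs() helper above.
m = ModuleSimpleTrain()
ep = torch.export.export(m, m.get_inputs(), strict=True)
ep = _export_forward_backward(ep)
ep = to_edge(ep)
ep = ep.to_executorch()
tm = _load_for_executorch_for_training_from_buffer(ep.buffer)

# The first forward_backward also populates named_parameters / named_gradients.
loss = tm.forward_backward("forward", m.get_inputs())
optimizer = get_sgd_optimizer(
    tm.named_parameters(),
    0.1,
    # ... remaining hyperparameters exactly as in the test above ...
    0,
    False,
)

for _ in range(10):
    optimizer.step(tm.named_gradients())
    loss = tm.forward_backward("forward", m.get_inputs())
    print(loss[0])  # the loss should generally trend down as the real weights move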
