@@ -28,20 +28,19 @@ def __init__(self):
         def forward(self, x, y):
             return self.loss(self.linear(x).softmax(dim=0), y)
 
-        def get_random_inputs(self):
-            return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))
+        def get_inputs(self):
+            return (torch.ones(3, dtype=torch.float32), torch.tensor([1.0, 0.0, 0.0]))
 
     def test(self):
         m = self.ModuleSimpleTrain()
-        ep = torch.export.export(m, m.get_random_inputs(), strict=True)
+        ep = torch.export.export(m, m.get_inputs(), strict=True)
         ep = _export_forward_backward(ep)
         ep = to_edge(ep)
         ep = ep.to_executorch()
         buffer = ep.buffer
         tm = _load_for_executorch_for_training_from_buffer(buffer)
 
-        tm.forward_backward("forward", m.get_random_inputs())
-        orig_param = list(tm.named_parameters().values())[0].clone()
+        orig_loss = tm.forward_backward("forward", m.get_inputs())
         optimizer = get_sgd_optimizer(
             tm.named_parameters(),
             0.1,
@@ -50,7 +49,19 @@ def test(self):
             0,
             False,
         )
+
+        cloned_params = list(tm.named_parameters().values())
+        cloned_params = [p.clone() for p in cloned_params]
+
         optimizer.step(tm.named_gradients())
-        self.assertFalse(
-            torch.allclose(orig_param, list(tm.named_parameters().values())[0])
-        )
+
+        # The Python module caches the param tensors after the first
+        # inference, so this doesn't verify that the params were actually
+        # updated on the C++ side.
+        for p, cloned_p in zip(tm.named_parameters().values(), cloned_params):
+            self.assertFalse(torch.allclose(p, cloned_p))
+
+        # Verify that the params actually changed in C++ by running the
+        # same inputs again and checking that the loss is different.
+        second_loss = tm.forward_backward("forward", m.get_inputs())
+        self.assertFalse(torch.allclose(orig_loss[0], second_loss[0]))
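
For readers who want to try the flow this test exercises outside the test harness, here is a minimal standalone sketch. The module internals (Linear(3, 3) plus CrossEntropyLoss), the import paths, and the two SGD hyperparameters that fall between the hunks are assumptions inferred from the diff, not code quoted from the repository.

```python
# Minimal sketch of the export-and-train flow exercised by the test above.
# Import paths and module internals are assumptions inferred from the diff;
# the test's real imports live outside this hunk.
import torch
from executorch.exir import to_edge
from torch.export.experimental import _export_forward_backward

# Assumed location of the training pybindings used by the test.
from executorch.extension.training.pybindings._training_lib import (
    _load_for_executorch_for_training_from_buffer,
    get_sgd_optimizer,
)


class SimpleTrain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)      # inferred from the 3-element input
        self.loss = torch.nn.CrossEntropyLoss()  # inferred from the one-hot target

    def forward(self, x, y):
        return self.loss(self.linear(x).softmax(dim=0), y)


m = SimpleTrain()
# Deterministic inputs, matching the new get_inputs(), so that re-running the
# same batch after an optimizer step yields a comparable loss.
inputs = (torch.ones(3, dtype=torch.float32), torch.tensor([1.0, 0.0, 0.0]))

# Export the joint forward+backward graph, lower it to Edge, and serialize.
ep = torch.export.export(m, inputs, strict=True)
ep = _export_forward_backward(ep)
buffer = to_edge(ep).to_executorch().buffer

# Load the serialized program with the training runtime bindings.
tm = _load_for_executorch_for_training_from_buffer(buffer)

# One training step: forward+backward, then SGD on the reported gradients.
first_loss = tm.forward_backward("forward", inputs)
# The two arguments between the learning rate (0.1) and the trailing
# (0, False) are outside this hunk; 0.0, 0.0 here is a guess.
optimizer = get_sgd_optimizer(tm.named_parameters(), 0.1, 0.0, 0.0, 0, False)
optimizer.step(tm.named_gradients())

# Re-running the same inputs should now produce a different loss.
second_loss = tm.forward_backward("forward", inputs)
print(first_loss[0], second_loss[0])
```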