@@ -130,11 +130,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 aten_dialect: ExportedProgram = export(f, example_args)
 
 # Works correctly
-print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
+print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
 
 # Errors
 try:
-    print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
+    print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
 except Exception:
     tb.print_exc()
 
@@ -175,18 +175,18 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # Now let's try running the model with different shapes:
 
 # Works correctly
-print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
-print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
+print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
+print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
 
 # Errors because it violates our constraint that input 0, dim 1 <= 10
 try:
-    print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
+    print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15)))
 except Exception:
     tb.print_exc()
 
 # Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
 try:
-    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
+    print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2)))
 except Exception:
     tb.print_exc()
 
@@ -287,23 +287,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # there is only one program, it will by default be saved to the name "forward".
 
 
-def encode(x):
-    return torch.nn.functional.linear(x, torch.randn(5, 10))
+class Encode(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.linear(x, torch.randn(5, 10))
 
 
-def decode(x):
-    return torch.nn.functional.linear(x, torch.randn(10, 5))
+class Decode(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.linear(x, torch.randn(10, 5))
 
 
 encode_args = (torch.randn(1, 10),)
 aten_encode: ExportedProgram = export(
-    capture_pre_autograd_graph(encode, encode_args),
+    capture_pre_autograd_graph(Encode(), encode_args),
     encode_args,
 )
 
 decode_args = (torch.randn(1, 5),)
 aten_decode: ExportedProgram = export(
-    capture_pre_autograd_graph(decode, decode_args),
+    capture_pre_autograd_graph(Decode(), decode_args),
     decode_args,
 )
 
@@ -486,17 +488,18 @@ def forward(self, x):
 # ``LoweredBackendModule`` for each of those subgraphs.
 
 
-def f(a, x, b):
-    y = torch.mm(a, x)
-    z = y + b
-    a = z - a
-    y = torch.mm(a, x)
-    z = y + b
-    return z
+class Foo(torch.nn.Module):
+    def forward(self, a, x, b):
+        y = torch.mm(a, x)
+        z = y + b
+        a = z - a
+        y = torch.mm(a, x)
+        z = y + b
+        return z
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
+pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
 aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
@@ -520,17 +523,18 @@ def f(a, x, b):
 # call ``to_backend`` on it:
 
 
-def f(a, x, b):
-    y = torch.mm(a, x)
-    z = y + b
-    a = z - a
-    y = torch.mm(a, x)
-    z = y + b
-    return z
+class Foo(torch.nn.Module):
+    def forward(self, a, x, b):
+        y = torch.mm(a, x)
+        z = y + b
+        a = z - a
+        y = torch.mm(a, x)
+        z = y + b
+        return z
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
+pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
 aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
0 commit comments