@@ -1012,17 +1012,19 @@ def false_fn(x, y):
1012
1012
x = x - y
1013
1013
return x
1014
1014
1015
- def f (x , y ):
1016
- x = x + y
1017
- x = control_flow .cond (x [0 ][0 ] == 1 , true_fn , false_fn , [x , y ])
1018
- x = x - y
1019
- return x
1015
+ class Module (torch .nn .Module ):
1016
+ def forward (self , x , y ):
1017
+ x = x + y
1018
+ x = control_flow .cond (x [0 ][0 ] == 1 , true_fn , false_fn , [x , y ])
1019
+ x = x - y
1020
+ return x
1020
1021
1022
+ f = Module ()
1021
1023
inputs = (torch .ones (2 , 2 ), torch .ones (2 , 2 ))
1022
1024
orig_res = f (* inputs )
1023
1025
orig = to_edge (
1024
1026
export (
1025
- torch . export . WrapperModule ( f ) ,
1027
+ f ,
1026
1028
inputs ,
1027
1029
)
1028
1030
)
@@ -1066,15 +1068,17 @@ def map_fn(x, y):
1066
1068
x = x + y
1067
1069
return x
1068
1070
1069
- def f (xs , y ):
1070
- y = torch .mm (y , y )
1071
- return control_flow .map (map_fn , xs , y )
1071
+ class Module (torch .nn .Module ):
1072
+ def forward (self , xs , y ):
1073
+ y = torch .mm (y , y )
1074
+ return control_flow .map (map_fn , xs , y )
1072
1075
1076
+ f = Module ()
1073
1077
inputs = (torch .ones (2 , 2 ), torch .ones (2 , 2 ))
1074
1078
orig_res = f (* inputs )
1075
1079
orig = to_edge (
1076
1080
export (
1077
- torch . export . WrapperModule ( f ) ,
1081
+ f ,
1078
1082
inputs ,
1079
1083
)
1080
1084
)
@@ -1132,9 +1136,10 @@ def map_fn(x, pred1, pred2, y):
1132
1136
x = x + y
1133
1137
return x .sin ()
1134
1138
1135
- def f (xs , pred1 , pred2 , y ):
1136
- y = torch .mm (y , y )
1137
- return control_flow .map (map_fn , xs , pred1 , pred2 , y )
1139
+ class Module (torch .nn .Module ):
1140
+ def forward (self , xs , pred1 , pred2 , y ):
1141
+ y = torch .mm (y , y )
1142
+ return control_flow .map (map_fn , xs , pred1 , pred2 , y )
1138
1143
1139
1144
inputs = (
1140
1145
torch .ones (2 , 2 ),
@@ -1143,10 +1148,11 @@ def f(xs, pred1, pred2, y):
1143
1148
torch .ones (2 , 2 ),
1144
1149
)
1145
1150
1151
+ f = Module ()
1146
1152
orig_res = f (* inputs )
1147
1153
orig = to_edge (
1148
1154
export (
1149
- torch . export . WrapperModule ( f ) ,
1155
+ f ,
1150
1156
inputs ,
1151
1157
)
1152
1158
)
@@ -1205,12 +1211,14 @@ def f(xs, pred1, pred2, y):
1205
1211
)
1206
1212
1207
1213
def test_list_input (self ):
1208
- def f (x : List [torch .Tensor ]):
1209
- y = x [0 ] + x [1 ]
1210
- return y
1214
+ class Module (torch .nn .Module ):
1215
+ def forward (self , x : List [torch .Tensor ]):
1216
+ y = x [0 ] + x [1 ]
1217
+ return y
1211
1218
1219
+ f = Module ()
1212
1220
inputs = ([torch .randn (2 , 2 ), torch .randn (2 , 2 )],)
1213
- edge_prog = to_edge (export (torch . export . WrapperModule ( f ) , inputs ))
1221
+ edge_prog = to_edge (export (f , inputs ))
1214
1222
lowered_gm = to_backend (
1215
1223
BackendWithCompilerDemo .__name__ , edge_prog .exported_program (), []
1216
1224
)
@@ -1227,12 +1235,14 @@ def forward(self, x: List[torch.Tensor]):
1227
1235
gm .exported_program ().module ()(* inputs )
1228
1236
1229
1237
def test_dict_input (self ):
1230
- def f (x : Dict [str , torch .Tensor ]):
1231
- y = x ["a" ] + x ["b" ]
1232
- return y
1238
+ class Module (torch .nn .Module ):
1239
+ def forward (self , x : Dict [str , torch .Tensor ]):
1240
+ y = x ["a" ] + x ["b" ]
1241
+ return y
1233
1242
1243
+ f = Module ()
1234
1244
inputs = ({"a" : torch .randn (2 , 2 ), "b" : torch .randn (2 , 2 )},)
1235
- edge_prog = to_edge (export (torch . export . WrapperModule ( f ) , inputs ))
1245
+ edge_prog = to_edge (export (f , inputs ))
1236
1246
lowered_gm = to_backend (
1237
1247
BackendWithCompilerDemo .__name__ , edge_prog .exported_program (), []
1238
1248
)
0 commit comments