@@ -156,7 +156,7 @@ def test_edge_manager_transform(self):
156
156
)
157
157
158
158
original_res = edge_manager .exported_program ("forward" )(
159
- torch .ones (1 ), torch .ones (1 ), torch . ones ( 1 )
159
+ torch .ones (1 ), torch .ones (1 )
160
160
)
161
161
162
162
# perform transformation
@@ -173,17 +173,13 @@ def test_edge_manager_transform(self):
173
173
174
174
# transformation was applied
175
175
self .assertEqual (
176
- transformed_edge .exported_program ("forward" )(
177
- torch .ones (1 ), torch .ones (1 ), torch .ones (1 )
178
- ),
176
+ transformed_edge .exported_program ("forward" )(torch .ones (1 ), torch .ones (1 )),
179
177
torch .ones (1 ), # x * y * x
180
178
)
181
179
182
180
# original unchanged
183
181
self .assertEqual (
184
- edge_manager .exported_program ("forward" )(
185
- torch .ones (1 ), torch .ones (1 ), torch .ones (1 )
186
- ),
182
+ edge_manager .exported_program ("forward" )(torch .ones (1 ), torch .ones (1 )),
187
183
original_res , # x * y + x
188
184
)
189
185
@@ -199,9 +195,7 @@ def test_transform_dict_api(self):
199
195
)
200
196
201
197
self .assertEqual (
202
- transformed_edge .exported_program ("forward" )(
203
- torch .ones (1 ), torch .ones (1 ), torch .ones (1 )
204
- ),
198
+ transformed_edge .exported_program ("forward" )(torch .ones (1 ), torch .ones (1 )),
205
199
torch .ones (1 ), # x * y * x
206
200
)
207
201
@@ -222,7 +216,7 @@ def test_edge_to_backend_replaces_subgraph(self):
222
216
223
217
forward_program = delegate_manager .exported_program ("forward" )
224
218
self .assertEqual (
225
- forward_program (torch .ones (1 ), torch .ones (1 ), torch . ones ( 1 ) ),
219
+ forward_program (torch .ones (1 ), torch .ones (1 )),
226
220
torch .ones (1 ) + 1 , # x * y + x
227
221
)
228
222
0 commit comments