Skip to content

Commit a35d1ac

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Fix forward unittests.
Summary: from D49482310 Reviewed By: JacobSzwejbka, SherlockNoMad Differential Revision: D49547858 fbshipit-source-id: fb4c3b397174c585394511d7baf097fd7023c441
1 parent 167b72d commit a35d1ac

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

exir/program/test/test_program.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_edge_manager_transform(self):
156156
)
157157

158158
original_res = edge_manager.exported_program("forward")(
159-
torch.ones(1), torch.ones(1), torch.ones(1)
159+
torch.ones(1), torch.ones(1)
160160
)
161161

162162
# perform transformation
@@ -173,17 +173,13 @@ def test_edge_manager_transform(self):
173173

174174
# transformation was applied
175175
self.assertEqual(
176-
transformed_edge.exported_program("forward")(
177-
torch.ones(1), torch.ones(1), torch.ones(1)
178-
),
176+
transformed_edge.exported_program("forward")(torch.ones(1), torch.ones(1)),
179177
torch.ones(1), # x * y * x
180178
)
181179

182180
# original unchanged
183181
self.assertEqual(
184-
edge_manager.exported_program("forward")(
185-
torch.ones(1), torch.ones(1), torch.ones(1)
186-
),
182+
edge_manager.exported_program("forward")(torch.ones(1), torch.ones(1)),
187183
original_res, # x * y + x
188184
)
189185

@@ -199,9 +195,7 @@ def test_transform_dict_api(self):
199195
)
200196

201197
self.assertEqual(
202-
transformed_edge.exported_program("forward")(
203-
torch.ones(1), torch.ones(1), torch.ones(1)
204-
),
198+
transformed_edge.exported_program("forward")(torch.ones(1), torch.ones(1)),
205199
torch.ones(1), # x * y * x
206200
)
207201

@@ -222,7 +216,7 @@ def test_edge_to_backend_replaces_subgraph(self):
222216

223217
forward_program = delegate_manager.exported_program("forward")
224218
self.assertEqual(
225-
forward_program(torch.ones(1), torch.ones(1), torch.ones(1)),
219+
forward_program(torch.ones(1), torch.ones(1)),
226220
torch.ones(1) + 1, # x * y + x
227221
)
228222

0 commit comments

Comments
 (0)