Skip to content

Commit 36f9903

Browse files
Cleaned up test.
1 parent 0f1bc67 commit 36f9903

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

exir/tests/test_passes.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,14 +1077,15 @@ def forward(self) -> torch.Tensor:
10771077
FileCheck().check("_lifted_tensor_constant1").check(
10781078
"b_a" # followed by the buffer input.
10791079
).run(ep.graph_module.code)
1080+
10801081
# the graph signature should also be the same:
1081-
assert ep.graph_signature.input_specs[0].arg.name == "_lifted_tensor_constant1"
1082-
assert ep.graph_signature.input_specs[1].arg.name == "b_a"
1082+
self.assertEqual(
1083+
ep.graph_signature.input_specs[0].arg.name, "_lifted_tensor_constant1"
1084+
)
1085+
self.assertEqual(ep.graph_signature.input_specs[1].arg.name, "b_a")
10831086

1084-
executorch_program = edge.to_executorch()
1085-
# # the graph signature should also be the same:
1086-
# executorch_program.graph_signature.input_specs[0].arg.name == "_lifted_tensor_constant1"
1087-
# executorch_program.graph_signature.input_specs[1].arg.name == "b_a"
1087+
# Validate that the program successfully passes validation to executorch:
1088+
edge.to_executorch()
10881089

10891090
def test_constant_prop_pass_for_parameter(self) -> None:
10901091
def count_additions(gm: torch.fx.GraphModule) -> int:

0 commit comments

Comments
 (0)