Skip to content

Commit f739212

Browse files
angelayi authored and facebook-github-bot committed
Disable exported_program.__call__ (#1954)
Summary: Pull Request resolved: #1954 X-link: pytorch/pytorch#119466 `ExportedProgram` is an artifact produced by torch.export, containing the graph that is exported, along with other attributes about the original program such as the graph signature, state dict, and constants. One slightly confusing thing that users run into is that they treat the `ExportedProgram` as a `torch.nn.Module`, since the object is callable. However, as we do not plan to support all features that `torch.nn.Module`s have, like hooks, we want to create a distinction between it and the `ExportedProgram` by removing the `__call__` method. Instead users can create a proper `torch.nn.Module` through `exported_program.module()` and use that as a callable. Reviewed By: zhxchen17 Differential Revision: D53075378 fbshipit-source-id: e17671b655b6f410fd5e70971c246f36558324ff
1 parent 8fe00b5 commit f739212

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

exir/backend/test/test_backends_lifted.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def forward(self, x_raw, h, c):
626626
),
627627
)
628628

629-
new_res = program_with_delegates.exported_program()(*inputs)
629+
new_res = program_with_delegates.exported_program().module()(*inputs)
630630
for t1, t2 in zip(new_res, orig_res, strict=True):
631631
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
632632

@@ -745,7 +745,7 @@ def forward(self, x_raw, h, c):
745745
HTAPartitionerOnePatternDemo()
746746
)
747747

748-
new_res = traced_with_delegate.exported_program()(*inputs)
748+
new_res = traced_with_delegate.exported_program().module()(*inputs)
749749
for t1, t2 in zip(new_res, orig_res, strict=True):
750750
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
751751

@@ -768,7 +768,7 @@ def forward(self, x_raw, h, c):
768768
# config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
769769
# )
770770

771-
new_res = program_with_delegates.exported_program()(*inputs)
771+
new_res = program_with_delegates.exported_program().module()(*inputs)
772772
for t1, t2 in zip(new_res, orig_res, strict=True):
773773
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
774774

@@ -1029,7 +1029,7 @@ def f(x, y):
10291029
partitioned = orig
10301030
partitioned = partitioned.to_backend(AddMulPartitionerDemo())
10311031

1032-
new_res = partitioned.exported_program()(*inputs)
1032+
new_res = partitioned.exported_program().module()(*inputs)
10331033
self.assertTrue(torch.allclose(orig_res, new_res[0]))
10341034

10351035
toplevel_lowered = get_lowered_submodules(
@@ -1102,7 +1102,7 @@ def f(xs, y):
11021102
map_fn_lowered[0][1].original_module.graph_module.code
11031103
)
11041104

1105-
new_res = partitioned.exported_program()(*inputs)
1105+
new_res = partitioned.exported_program().module()(*inputs)
11061106

11071107
self.assertTrue(torch.allclose(orig_res, new_res[0]))
11081108

@@ -1153,7 +1153,7 @@ def f(xs, pred1, pred2, y):
11531153
partitioned = orig
11541154
partitioned = partitioned.to_backend(AddMulPartitionerDemo())
11551155

1156-
new_res = partitioned.exported_program()(*inputs)
1156+
new_res = partitioned.exported_program().module()(*inputs)
11571157
self.assertTrue(torch.allclose(orig_res, new_res[0]))
11581158

11591159
toplevel_lowered = get_lowered_submodules(
@@ -1224,7 +1224,7 @@ def forward(self, x: List[torch.Tensor]):
12241224
return self.lowered(x)
12251225

12261226
gm = to_edge(export(ComposedM(), inputs))
1227-
gm.exported_program()(*inputs)
1227+
gm.exported_program().module()(*inputs)
12281228

12291229
def test_dict_input(self):
12301230
def f(x: Dict[str, torch.Tensor]):
@@ -1246,4 +1246,4 @@ def forward(self, x: List[torch.Tensor]):
12461246
return self.lowered(x)
12471247

12481248
gm = to_edge(export(ComposedM(), inputs))
1249-
gm.exported_program()(*inputs)
1249+
gm.exported_program().module()(*inputs)

exir/tests/test_verification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
136136
exec_prog = to_edge(export(model2, (inputs,))).to_executorch()
137137

138138
exported_prog = exec_prog.exported_program()
139-
res = exported_prog(inputs)[0] # noqa
139+
res = exported_prog.module()(inputs)[0] # noqa
140140
# Verifiers are run internally in to_edge, export, and to_executorch.
141141
# If we make it this far then no errors were thrown in verification
142142

0 commit comments

Comments (0)