Skip to content

Commit 30acae5

Browse files
authored
Switch over backend tests to export_for_training
Differential Revision: D62428363. Pull Request resolved: #5220.
1 parent 43e2f2d commit 30acae5

File tree

4 files changed

+21
-23
lines changed

4 files changed

+21
-23
lines changed

backends/example/test_example_delegate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_example_inputs():
4646
)
4747

4848
m = model.eval()
49-
m = torch._export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
49+
m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
5050
# print("original model:", m)
5151
quantizer = ExampleQuantizer()
5252
# quantizer = XNNPACKQuantizer()
@@ -82,7 +82,7 @@ def test_delegate_mobilenet_v2(self):
8282
)
8383

8484
m = model.eval()
85-
m = torch._export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
85+
m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
8686
quantizer = ExampleQuantizer()
8787

8888
m = prepare_pt2e(m, quantizer)

exir/backend/test/TARGETS

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,14 @@ python_library(
8282
"//executorch/test/...",
8383
],
8484
deps = [
85-
":backend_with_compiler_demo",
86-
"//caffe2:torch",
87-
"//executorch/exir:graph_module",
88-
"//executorch/exir/backend:compile_spec_schema",
89-
"//executorch/exir/backend:partitioner",
90-
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
91-
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
92-
"//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
93-
"//executorch/exir/dialects:lib",
85+
"fbcode//caffe2:torch",
86+
"fbcode//executorch/exir:graph_module",
87+
"fbcode//executorch/exir/backend:compile_spec_schema",
88+
"fbcode//executorch/exir/backend:partitioner",
89+
"fbcode//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
90+
"fbcode//executorch/exir/backend/test:backend_with_compiler_demo",
91+
"fbcode//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
92+
"fbcode//executorch/exir/dialects:lib",
9493
],
9594
)
9695

exir/backend/test/test_partitioner.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@
3939
_load_for_executorch_from_buffer,
4040
)
4141
from executorch.extension.pytree import tree_flatten
42-
from torch._export import capture_pre_autograd_graph
4342
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
44-
from torch.export import export
43+
from torch.export import export, export_for_training
4544
from torch.fx.passes.operator_support import any_chain
4645

4746

@@ -77,7 +76,7 @@ def partition(
7776

7877
mlp = MLP()
7978
example_inputs = mlp.get_random_inputs()
80-
model = capture_pre_autograd_graph(mlp, example_inputs)
79+
model = export_for_training(mlp, example_inputs).module()
8180
aten = export(model, example_inputs)
8281
spec_key = "path"
8382
spec_value = "/a/b/c/d"
@@ -138,7 +137,7 @@ def partition(
138137

139138
mlp = MLP()
140139
example_inputs = mlp.get_random_inputs()
141-
model = capture_pre_autograd_graph(mlp, example_inputs)
140+
model = export_for_training(mlp, example_inputs).module()
142141
aten = export(model, example_inputs)
143142
edge = exir.to_edge(aten)
144143

@@ -178,7 +177,7 @@ def partition(
178177

179178
mlp = MLP()
180179
example_inputs = mlp.get_random_inputs()
181-
model = capture_pre_autograd_graph(mlp, example_inputs)
180+
model = export_for_training(mlp, example_inputs).module()
182181
edge = exir.to_edge(export(model, example_inputs))
183182

184183
with self.assertRaisesRegex(
@@ -230,7 +229,7 @@ def partition(
230229
partition_tags=partition_tags,
231230
)
232231

233-
model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),))
232+
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
234233
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
235234
delegated = edge.to_backend(PartitionerNoTagData())
236235

@@ -309,7 +308,7 @@ def partition(
309308
partition_tags=partition_tags,
310309
)
311310

312-
model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),))
311+
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
313312
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
314313
delegated = edge.to_backend(PartitionerTagData())
315314

@@ -384,7 +383,7 @@ def partition(
384383
partition_tags=partition_tags,
385384
)
386385

387-
model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),))
386+
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
388387
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
389388
delegated = edge.to_backend(PartitionerTagData())
390389

@@ -472,7 +471,7 @@ def partition(
472471
)
473472

474473
inputs = (torch.ones(2, 2),)
475-
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
474+
model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
476475
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
477476
exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
478477
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
@@ -532,7 +531,7 @@ def partition(
532531
partition_tags=partition_tags,
533532
)
534533

535-
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
534+
model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
536535
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
537536
with self.assertRaises(RuntimeError) as error:
538537
_ = edge.to_backend(PartitionerTagData())

exir/backend/test/test_passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
1212
duplicate_constant_node,
1313
)
14-
from torch._export import capture_pre_autograd_graph
1514
from torch._export.utils import is_buffer
15+
from torch.export import export_for_training
1616
from torch.testing import FileCheck
1717

1818

@@ -29,7 +29,7 @@ def forward(self, x):
2929
z = x - self.const
3030
return y, z
3131

32-
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
32+
model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
3333
edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))
3434

3535
const_nodes = [

0 commit comments

Comments (0)