|
39 | 39 | _load_for_executorch_from_buffer,
|
40 | 40 | )
|
41 | 41 | from executorch.extension.pytree import tree_flatten
|
42 |
| -from torch._export import capture_pre_autograd_graph |
43 | 42 | from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
|
44 |
| -from torch.export import export |
| 43 | +from torch.export import export, export_for_training |
45 | 44 | from torch.fx.passes.operator_support import any_chain
|
46 | 45 |
|
47 | 46 |
|
@@ -77,7 +76,7 @@ def partition(
|
77 | 76 |
|
78 | 77 | mlp = MLP()
|
79 | 78 | example_inputs = mlp.get_random_inputs()
|
80 |
| - model = capture_pre_autograd_graph(mlp, example_inputs) |
| 79 | + model = export_for_training(mlp, example_inputs).module() |
81 | 80 | aten = export(model, example_inputs)
|
82 | 81 | spec_key = "path"
|
83 | 82 | spec_value = "/a/b/c/d"
|
@@ -138,7 +137,7 @@ def partition(
|
138 | 137 |
|
139 | 138 | mlp = MLP()
|
140 | 139 | example_inputs = mlp.get_random_inputs()
|
141 |
| - model = capture_pre_autograd_graph(mlp, example_inputs) |
| 140 | + model = export_for_training(mlp, example_inputs).module() |
142 | 141 | aten = export(model, example_inputs)
|
143 | 142 | edge = exir.to_edge(aten)
|
144 | 143 |
|
@@ -178,7 +177,7 @@ def partition(
|
178 | 177 |
|
179 | 178 | mlp = MLP()
|
180 | 179 | example_inputs = mlp.get_random_inputs()
|
181 |
| - model = capture_pre_autograd_graph(mlp, example_inputs) |
| 180 | + model = export_for_training(mlp, example_inputs).module() |
182 | 181 | edge = exir.to_edge(export(model, example_inputs))
|
183 | 182 |
|
184 | 183 | with self.assertRaisesRegex(
|
@@ -230,7 +229,7 @@ def partition(
|
230 | 229 | partition_tags=partition_tags,
|
231 | 230 | )
|
232 | 231 |
|
233 |
| - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) |
| 232 | + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() |
234 | 233 | edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
|
235 | 234 | delegated = edge.to_backend(PartitionerNoTagData())
|
236 | 235 |
|
@@ -309,7 +308,7 @@ def partition(
|
309 | 308 | partition_tags=partition_tags,
|
310 | 309 | )
|
311 | 310 |
|
312 |
| - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) |
| 311 | + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() |
313 | 312 | edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
|
314 | 313 | delegated = edge.to_backend(PartitionerTagData())
|
315 | 314 |
|
@@ -384,7 +383,7 @@ def partition(
|
384 | 383 | partition_tags=partition_tags,
|
385 | 384 | )
|
386 | 385 |
|
387 |
| - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) |
| 386 | + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() |
388 | 387 | edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
|
389 | 388 | delegated = edge.to_backend(PartitionerTagData())
|
390 | 389 |
|
@@ -472,7 +471,7 @@ def partition(
|
472 | 471 | )
|
473 | 472 |
|
474 | 473 | inputs = (torch.ones(2, 2),)
|
475 |
| - model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),)) |
| 474 | + model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() |
476 | 475 | edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
|
477 | 476 | exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
|
478 | 477 | executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
|
@@ -532,7 +531,7 @@ def partition(
|
532 | 531 | partition_tags=partition_tags,
|
533 | 532 | )
|
534 | 533 |
|
535 |
| - model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),)) |
| 534 | + model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() |
536 | 535 | edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
|
537 | 536 | with self.assertRaises(RuntimeError) as error:
|
538 | 537 | _ = edge.to_backend(PartitionerTagData())
|
|
0 commit comments