Skip to content

Commit 6867190

Browse files
cccclaifacebook-github-bot
authored andcommitted
Add lift constant tensors passes after aten_to_edge
Summary: When exporting using enable_aot (through the torch.export path), we want to lift all constant tensors as buffers to the exported program. The ScalarToTensor pass in EXIR's aten_to_edge passes will create some constant tensors in the graph, so we will need to run a lift_constant_tensors pass afterwards. Note that this only needs to be applied when exporting using the torch.export path because in the original path, nothing is lifted. Differential Revision: D49207492
1 parent f0125ba commit 6867190

File tree

5 files changed

+66
-12
lines changed

5 files changed

+66
-12
lines changed

exir/backend/test/test_backends_lifted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def forward(self, x_raw, h, c):
647647
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
648648
)
649649

650-
new_res = program_with_delegates.dump_graph_module()(*inputs)
650+
new_res = program_with_delegates.dump_exported_program()(*inputs)
651651
for t1, t2 in zip(new_res, orig_res, strict=True):
652652
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
653653

@@ -780,7 +780,7 @@ def forward(self, x_raw, h, c):
780780
# config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
781781
# )
782782

783-
new_res = program_with_delegates.dump_graph_module()(*inputs)
783+
new_res = program_with_delegates.dump_exported_program()(*inputs)
784784
for t1, t2 in zip(new_res, orig_res, strict=True):
785785
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
786786

exir/capture/_capture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def convert_to_fake(x):
245245
{},
246246
[],
247247
[],
248+
dialect="OLD_EXIR_ATEN",
248249
)
249250
return ExirExportedProgram(ep, False)
250251

exir/program/_program.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
EXIREdgeDialectVerifier,
3030
)
3131
from torch._export import ExportedProgram
32+
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
3233
from torch.fx import _pytree as fx_pytree
3334
from torch.fx._compatibility import compatibility
3435
from torch.utils import _pytree as pytree
@@ -68,6 +69,7 @@ def __init__(
6869
module_call_graph,
6970
example_inputs,
7071
)
72+
self._dialect = "HACKED_ATEN"
7173

7274
def __call__(self, *args: Any, **kwargs: Any) -> Any:
7375
import torch._export.error as error
@@ -274,17 +276,23 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
274276
)
275277
raise
276278

277-
op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
278279
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
279280
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
280281
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
281282
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
282-
passes = (
283-
aten_to_edge_passes.passes[:-2]
284-
+ op_replace_pass
285-
+ aten_to_edge_passes.passes[-2:]
286-
)
287-
new_ep = copy.deepcopy(ep).transform(*passes)
283+
pre_op_replace_passes = aten_to_edge_passes.passes[:-2]
284+
post_op_replace_passes = aten_to_edge_passes.passes[-2:]
285+
286+
new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
287+
if new_ep.exported_program.dialect == "ATEN":
288+
new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
289+
290+
if config._use_edge_ops:
291+
new_ep = new_ep.transform(OpReplacePass())
292+
293+
new_ep = new_ep.transform(*post_op_replace_passes)
294+
new_ep.exported_program.dialect = "EDGE"
295+
288296
if config._check_ir_validity:
289297
EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(
290298
new_ep.exported_program.graph_module

exir/tests/test_passes.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from functorch.experimental import control_flow
4747

4848
from torch import nn
49+
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
4950
from torch.fx import GraphModule, subgraph_rewriter
5051
from torch.fx.experimental.proxy_tensor import make_fx
5152
from torch.library import impl, Library
@@ -464,6 +465,50 @@ def mul(x: torch.Tensor) -> torch.Tensor:
464465
for arg in node.args + tuple(node.kwargs.values()):
465466
self.assertFalse(isinstance(arg, float))
466467

468+
def test_lift_scalar_tensor(self) -> None:
469+
class FooWithBuffer(torch.nn.Module):
470+
def __init__(self):
471+
super().__init__()
472+
self.register_buffer("buffer", torch.zeros(42))
473+
474+
def forward(self, x):
475+
return x.cos() + self.buffer.sum() + torch.tensor(4) + 4
476+
477+
ep = exir.capture(
478+
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
479+
)
480+
new_ep = ep.transform(ScalarToTensorPass()).exported_program
481+
lifted_exported_program = lift_constant_tensor_pass(new_ep)
482+
483+
self.assertEqual(
484+
len(
485+
[
486+
node
487+
for node in lifted_exported_program.graph.nodes
488+
if node.op == "placeholder"
489+
]
490+
),
491+
4,
492+
)
493+
for node in lifted_exported_program.graph.nodes:
494+
self.assertTrue(node.op != "get_attr")
495+
496+
edge_ep = exir.capture(
497+
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
498+
).to_edge()
499+
self.assertEqual(
500+
len(
501+
[
502+
node
503+
for node in edge_ep.exported_program.graph.nodes
504+
if node.op == "placeholder"
505+
]
506+
),
507+
4,
508+
)
509+
for node in edge_ep.exported_program.graph.nodes:
510+
self.assertTrue(node.op != "get_attr")
511+
467512
def test_remove_mixed_types_symfloats(self) -> None:
468513
def f(x: torch.Tensor) -> torch.Tensor:
469514
return torch.nn.functional.interpolate(

exir/tests/test_tracer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,8 @@ def forward(self, x):
398398
return x.cos() + self.buffer.sum()
399399

400400
capture_config = exir.CaptureConfig(enable_aot=True)
401-
captured_gm = exir.capture(
402-
FooWithBuffer(), (torch.ones(6, 2),), capture_config
403-
).exported_program.graph_module
401+
captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
402+
captured_gm = captured_ep.exported_program.graph_module
404403

405404
placeholder_nodes = set()
406405
print(captured_gm.graph)
@@ -420,6 +419,7 @@ def forward(self, x):
420419
)
421420

422421
self.assertEqual(len(placeholder_nodes), 2)
422+
captured_ep.to_edge()
423423

424424
def test_export_unlift(self) -> None:
425425
class Foo(torch.nn.Module):

0 commit comments

Comments
 (0)