
Commit b28754a

cccclai authored and facebook-github-bot committed
Add lift constant tensors passes after aten_to_edge (#359)
Summary:
X-link: pytorch/pytorch#109382

When exporting with enable_aot (through the torch.export path), we want to lift all constant tensors as buffers on the exported program. The ScalarToTensor pass in EXIR's aten_to_edge passes creates some constant tensors in the graph, so a lift_constant_tensors pass has to run afterwards. Note that this only applies when exporting through the torch.export path; in the original path, nothing is lifted.

Reviewed By: cccclai

Differential Revision: D49207492
1 parent c52000a commit b28754a
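
For orientation, here is a minimal usage sketch of the flow this commit changes, adapted from the new test in exir/tests/test_passes.py below. It assumes a build that includes this change; the import path for exir is an assumption, since the tests simply refer to it as exir.

import torch

from executorch import exir  # import path assumed; the tests below just use `exir`


class FooWithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros(42))

    def forward(self, x):
        # ScalarToTensorPass (part of aten_to_edge) turns the trailing scalar 4
        # into a constant tensor, which is why lifting must run after it.
        return x.cos() + self.buffer.sum() + torch.tensor(4) + 4


# torch.export path: enable_aot=True, then lower to the edge dialect.
edge_ep = exir.capture(
    FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
).to_edge()

# With this change, constant tensors are lifted to inputs of the exported
# program: only placeholders remain (x, the registered buffer, and the two
# lifted constants), and no get_attr nodes.
graph = edge_ep.exported_program.graph
assert not any(node.op == "get_attr" for node in graph.nodes)
assert sum(node.op == "placeholder" for node in graph.nodes) == 4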

File tree

5 files changed: +78 -11 lines

exir/backend/test/test_backends_lifted.py

Lines changed: 2 additions & 2 deletions

@@ -647,7 +647,7 @@ def forward(self, x_raw, h, c):
             config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
         )
 
-        new_res = program_with_delegates.dump_graph_module()(*inputs)
+        new_res = program_with_delegates.dump_exported_program()(*inputs)
         for t1, t2 in zip(new_res, orig_res, strict=True):
             self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

@@ -780,7 +780,7 @@ def forward(self, x_raw, h, c):
         # config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
         # )
 
-        new_res = program_with_delegates.dump_graph_module()(*inputs)
+        new_res = program_with_delegates.dump_exported_program()(*inputs)
         for t1, t2 in zip(new_res, orig_res, strict=True):
             self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

exir/capture/_capture.py

Lines changed: 1 addition & 0 deletions

@@ -245,6 +245,7 @@ def convert_to_fake(x):
         {},
         [],
         [],
+        dialect="OLD_EXIR_ATEN",
     )
     return ExirExportedProgram(ep, False)

exir/program/_program.py

Lines changed: 24 additions & 6 deletions

@@ -30,6 +30,7 @@
     EXIREdgeDialectVerifier,
 )
 from torch._export import ExportedProgram
+from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
 from torch.fx import _pytree as fx_pytree
 from torch.fx._compatibility import compatibility
 from torch.utils import _pytree as pytree

@@ -69,6 +70,7 @@ def __init__(
             module_call_graph,
             example_inputs,
         )
+        self._dialect = "HACKED_ATEN"
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         import torch._export.error as error

@@ -283,17 +285,33 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
         )
         raise
 
-    op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
     # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
     # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
     # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
     # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
-    passes = (
-        aten_to_edge_passes.passes[:-2]
-        + op_replace_pass
-        + aten_to_edge_passes.passes[-2:]
+    pre_op_replace_passes = aten_to_edge_passes.passes[:-2]
+    post_op_replace_passes = aten_to_edge_passes.passes[-2:]
+
+    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
+    if new_ep.exported_program.dialect == "ATEN":
+        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
+
+    if config._use_edge_ops:
+        new_ep = new_ep.transform(OpReplacePass())
+
+    new_ep = new_ep.transform(*post_op_replace_passes)
+    new_ep.exported_program = ExportedProgram(
+        new_ep.exported_program.graph_module,
+        new_ep.exported_program.graph,
+        new_ep.exported_program.graph_signature,
+        new_ep.exported_program.call_spec,
+        new_ep.exported_program.state_dict,
+        new_ep.exported_program.range_constraints,
+        new_ep.exported_program.equality_constraints,
+        new_ep.exported_program.module_call_graph,
+        new_ep.exported_program.example_inputs,
+        dialect="EDGE",
     )
-    new_ep = copy.deepcopy(ep).transform(*passes)
     if config._check_ir_validity:
         EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(
             new_ep.exported_program.graph_module
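
To illustrate what "lifting" means here, below is a simplified, hypothetical torch.fx sketch: it only rewrites get_attr constant-tensor nodes into placeholders. The real lift_constant_tensor_pass operates on an ExportedProgram and also updates its graph signature and state_dict so the lifted tensors are tracked as buffers, which this toy helper does not attempt.

import torch
import torch.fx


class AddConst(torch.nn.Module):
    """Toy module whose traced graph holds a constant tensor as a get_attr node."""

    def __init__(self):
        super().__init__()
        # A plain tensor attribute, analogous to the constants that
        # ScalarToTensorPass materializes in the graph.
        self.const = torch.tensor(4.0)

    def forward(self, x):
        return x + self.const


def lift_constants_to_placeholders(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical helper: turn every get_attr node into a graph input."""
    for node in list(gm.graph.nodes):
        if node.op != "get_attr":
            continue
        with gm.graph.inserting_before(node):
            lifted = gm.graph.placeholder("lifted_" + node.target.replace(".", "_"))
        node.replace_all_uses_with(lifted)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = torch.fx.symbolic_trace(AddConst())
print([n.op for n in gm.graph.nodes])  # ['placeholder', 'get_attr', 'call_function', 'output']

gm = lift_constants_to_placeholders(gm)
print([n.op for n in gm.graph.nodes])  # ['placeholder', 'placeholder', 'call_function', 'output']

# The constant is now an explicit input rather than state baked into the graph,
# matching the commit's intent of treating constant tensors like buffers.
print(gm(torch.ones(3), torch.tensor(4.0)))  # tensor([5., 5., 5.])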

exir/tests/test_passes.py

Lines changed: 48 additions & 0 deletions

@@ -46,6 +46,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.library import impl, Library

@@ -464,6 +465,53 @@ def mul(x: torch.Tensor) -> torch.Tensor:
         for arg in node.args + tuple(node.kwargs.values()):
             self.assertFalse(isinstance(arg, float))
 
+    def test_lift_scalar_tensor(self) -> None:
+        class FooWithBuffer(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buffer", torch.zeros(42))
+
+            def forward(self, x):
+                return x.cos() + self.buffer.sum() + torch.tensor(4) + 4
+
+        ep = exir.capture(
+            FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
+        )
+        new_ep = ep.transform(ScalarToTensorPass()).exported_program
+        self.assertTrue(
+            len([node for node in new_ep.graph.nodes if node.op == "get_attr"])
+        )
+        lifted_exported_program = lift_constant_tensor_pass(new_ep)
+
+        self.assertEqual(
+            len(
+                [
+                    node
+                    for node in lifted_exported_program.graph.nodes
+                    if node.op == "placeholder"
+                ]
+            ),
+            4,
+        )
+        for node in lifted_exported_program.graph.nodes:
+            self.assertTrue(node.op != "get_attr")
+
+        edge_ep = exir.capture(
+            FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
+        ).to_edge()
+        self.assertEqual(
+            len(
+                [
+                    node
+                    for node in edge_ep.exported_program.graph.nodes
+                    if node.op == "placeholder"
+                ]
+            ),
+            4,
+        )
+        for node in edge_ep.exported_program.graph.nodes:
+            self.assertTrue(node.op != "get_attr")
+
     def test_remove_mixed_types_symfloats(self) -> None:
         def f(x: torch.Tensor) -> torch.Tensor:
             return torch.nn.functional.interpolate(

exir/tests/test_tracer.py

Lines changed: 3 additions & 3 deletions

@@ -398,9 +398,8 @@ def forward(self, x):
                 return x.cos() + self.buffer.sum()
 
         capture_config = exir.CaptureConfig(enable_aot=True)
-        captured_gm = exir.capture(
-            FooWithBuffer(), (torch.ones(6, 2),), capture_config
-        ).exported_program.graph_module
+        captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
+        captured_gm = captured_ep.exported_program.graph_module
 
         placeholder_nodes = set()
         print(captured_gm.graph)

@@ -420,6 +419,7 @@ def forward(self, x):
         )
 
         self.assertEqual(len(placeholder_nodes), 2)
+        captured_ep.to_edge()
 
     def test_export_unlift(self) -> None:
         class Foo(torch.nn.Module):
