Skip to content

Commit e03bd6c

Browse files
angelayi authored and facebook-github-bot committed
Add lift constant tensors passes after aten_to_edge (#359)
Summary: X-link: pytorch/pytorch#109382 Pull Request resolved: #359 When exporting using enable_aot (through the torch.export path), we want to lift all constant tensors as buffers to the exported program. The ScalarToTensor pass in EXIR's aten_to_edge passes will create some constant tensors in the graph, so we will need to run a lift_constant_tensors pass afterwards. Note that this only needs to be applied when exporting using the torch.export path because in the original path, nothing is lifted. Reviewed By: cccclai Differential Revision: D49207492 fbshipit-source-id: 971c44d109b41de8b3f9ff0565fd39bdd3a17b9f
1 parent 49d2e68 commit e03bd6c

File tree

5 files changed

+78
-11
lines changed

5 files changed

+78
-11
lines changed

exir/backend/test/test_backends_lifted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def forward(self, x_raw, h, c):
647647
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
648648
)
649649

650-
new_res = program_with_delegates.dump_graph_module()(*inputs)
650+
new_res = program_with_delegates.dump_exported_program()(*inputs)
651651
for t1, t2 in zip(new_res, orig_res, strict=True):
652652
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
653653

@@ -780,7 +780,7 @@ def forward(self, x_raw, h, c):
780780
# config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
781781
# )
782782

783-
new_res = program_with_delegates.dump_graph_module()(*inputs)
783+
new_res = program_with_delegates.dump_exported_program()(*inputs)
784784
for t1, t2 in zip(new_res, orig_res, strict=True):
785785
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))
786786

exir/capture/_capture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def convert_to_fake(x):
245245
{},
246246
[],
247247
[],
248+
dialect="OLD_EXIR_ATEN",
248249
)
249250
return ExirExportedProgram(ep, False)
250251

exir/program/_program.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
EXIREdgeDialectVerifier,
3030
)
3131
from torch._export import ExportedProgram
32+
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
3233
from torch.fx import _pytree as fx_pytree
3334
from torch.fx._compatibility import compatibility
3435
from torch.utils import _pytree as pytree
@@ -68,6 +69,7 @@ def __init__(
6869
module_call_graph,
6970
example_inputs,
7071
)
72+
self._dialect = "HACKED_ATEN"
7173

7274
def __call__(self, *args: Any, **kwargs: Any) -> Any:
7375
import torch._export.error as error
@@ -281,17 +283,33 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
281283
)
282284
raise
283285

284-
op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
285286
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
286287
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
287288
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
288289
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
289-
passes = (
290-
aten_to_edge_passes.passes[:-2]
291-
+ op_replace_pass
292-
+ aten_to_edge_passes.passes[-2:]
290+
pre_op_replace_passes = aten_to_edge_passes.passes[:-2]
291+
post_op_replace_passes = aten_to_edge_passes.passes[-2:]
292+
293+
new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
294+
if new_ep.exported_program.dialect == "ATEN":
295+
new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
296+
297+
if config._use_edge_ops:
298+
new_ep = new_ep.transform(OpReplacePass())
299+
300+
new_ep = new_ep.transform(*post_op_replace_passes)
301+
new_ep.exported_program = ExportedProgram(
302+
new_ep.exported_program.graph_module,
303+
new_ep.exported_program.graph,
304+
new_ep.exported_program.graph_signature,
305+
new_ep.exported_program.call_spec,
306+
new_ep.exported_program.state_dict,
307+
new_ep.exported_program.range_constraints,
308+
new_ep.exported_program.equality_constraints,
309+
new_ep.exported_program.module_call_graph,
310+
new_ep.exported_program.example_inputs,
311+
dialect="EDGE",
293312
)
294-
new_ep = copy.deepcopy(ep).transform(*passes)
295313
if config._check_ir_validity:
296314
EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(
297315
new_ep.exported_program.graph_module

exir/tests/test_passes.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from functorch.experimental import control_flow
4747

4848
from torch import nn
49+
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
4950
from torch.fx import GraphModule, subgraph_rewriter
5051
from torch.fx.experimental.proxy_tensor import make_fx
5152
from torch.library import impl, Library
@@ -464,6 +465,53 @@ def mul(x: torch.Tensor) -> torch.Tensor:
464465
for arg in node.args + tuple(node.kwargs.values()):
465466
self.assertFalse(isinstance(arg, float))
466467

468+
def test_lift_scalar_tensor(self) -> None:
469+
class FooWithBuffer(torch.nn.Module):
470+
def __init__(self):
471+
super().__init__()
472+
self.register_buffer("buffer", torch.zeros(42))
473+
474+
def forward(self, x):
475+
return x.cos() + self.buffer.sum() + torch.tensor(4) + 4
476+
477+
ep = exir.capture(
478+
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
479+
)
480+
new_ep = ep.transform(ScalarToTensorPass()).exported_program
481+
self.assertTrue(
482+
len([node for node in new_ep.graph.nodes if node.op == "get_attr"])
483+
)
484+
lifted_exported_program = lift_constant_tensor_pass(new_ep)
485+
486+
self.assertEqual(
487+
len(
488+
[
489+
node
490+
for node in lifted_exported_program.graph.nodes
491+
if node.op == "placeholder"
492+
]
493+
),
494+
4,
495+
)
496+
for node in lifted_exported_program.graph.nodes:
497+
self.assertTrue(node.op != "get_attr")
498+
499+
edge_ep = exir.capture(
500+
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
501+
).to_edge()
502+
self.assertEqual(
503+
len(
504+
[
505+
node
506+
for node in edge_ep.exported_program.graph.nodes
507+
if node.op == "placeholder"
508+
]
509+
),
510+
4,
511+
)
512+
for node in edge_ep.exported_program.graph.nodes:
513+
self.assertTrue(node.op != "get_attr")
514+
467515
def test_remove_mixed_types_symfloats(self) -> None:
468516
def f(x: torch.Tensor) -> torch.Tensor:
469517
return torch.nn.functional.interpolate(

exir/tests/test_tracer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,8 @@ def forward(self, x):
398398
return x.cos() + self.buffer.sum()
399399

400400
capture_config = exir.CaptureConfig(enable_aot=True)
401-
captured_gm = exir.capture(
402-
FooWithBuffer(), (torch.ones(6, 2),), capture_config
403-
).exported_program.graph_module
401+
captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
402+
captured_gm = captured_ep.exported_program.graph_module
404403

405404
placeholder_nodes = set()
406405
print(captured_gm.graph)
@@ -420,6 +419,7 @@ def forward(self, x):
420419
)
421420

422421
self.assertEqual(len(placeholder_nodes), 2)
422+
captured_ep.to_edge()
423423

424424
def test_export_unlift(self) -> None:
425425
class Foo(torch.nn.Module):

0 commit comments

Comments (0)