Skip to content

Add lift constant tensors passes after aten_to_edge #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions exir/backend/test/test_backends_lifted.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def forward(self, x_raw, h, c):
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
)

new_res = program_with_delegates.dump_graph_module()(*inputs)
new_res = program_with_delegates.dump_exported_program()(*inputs)
for t1, t2 in zip(new_res, orig_res, strict=True):
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

Expand Down Expand Up @@ -780,7 +780,7 @@ def forward(self, x_raw, h, c):
# config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
# )

new_res = program_with_delegates.dump_graph_module()(*inputs)
new_res = program_with_delegates.dump_exported_program()(*inputs)
for t1, t2 in zip(new_res, orig_res, strict=True):
self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

Expand Down
1 change: 1 addition & 0 deletions exir/capture/_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def convert_to_fake(x):
{},
[],
[],
dialect="OLD_EXIR_ATEN",
)
return ExirExportedProgram(ep, False)

Expand Down
30 changes: 24 additions & 6 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EXIREdgeDialectVerifier,
)
from torch._export import ExportedProgram
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
from torch.fx import _pytree as fx_pytree
from torch.fx._compatibility import compatibility
from torch.utils import _pytree as pytree
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
module_call_graph,
example_inputs,
)
self._dialect = "HACKED_ATEN"

def __call__(self, *args: Any, **kwargs: Any) -> Any:
import torch._export.error as error
Expand Down Expand Up @@ -281,17 +283,33 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
)
raise

op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
passes = (
aten_to_edge_passes.passes[:-2]
+ op_replace_pass
+ aten_to_edge_passes.passes[-2:]
pre_op_replace_passes = aten_to_edge_passes.passes[:-2]
post_op_replace_passes = aten_to_edge_passes.passes[-2:]

new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
if new_ep.exported_program.dialect == "ATEN":
new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)

if config._use_edge_ops:
new_ep = new_ep.transform(OpReplacePass())

new_ep = new_ep.transform(*post_op_replace_passes)
new_ep.exported_program = ExportedProgram(
new_ep.exported_program.graph_module,
new_ep.exported_program.graph,
new_ep.exported_program.graph_signature,
new_ep.exported_program.call_spec,
new_ep.exported_program.state_dict,
new_ep.exported_program.range_constraints,
new_ep.exported_program.equality_constraints,
new_ep.exported_program.module_call_graph,
new_ep.exported_program.example_inputs,
dialect="EDGE",
)
new_ep = copy.deepcopy(ep).transform(*passes)
if config._check_ir_validity:
EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(
new_ep.exported_program.graph_module
Expand Down
48 changes: 48 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from functorch.experimental import control_flow

from torch import nn
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import impl, Library
Expand Down Expand Up @@ -464,6 +465,53 @@ def mul(x: torch.Tensor) -> torch.Tensor:
for arg in node.args + tuple(node.kwargs.values()):
self.assertFalse(isinstance(arg, float))

def test_lift_scalar_tensor(self) -> None:
class FooWithBuffer(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.zeros(42))

def forward(self, x):
return x.cos() + self.buffer.sum() + torch.tensor(4) + 4

ep = exir.capture(
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
)
new_ep = ep.transform(ScalarToTensorPass()).exported_program
self.assertTrue(
len([node for node in new_ep.graph.nodes if node.op == "get_attr"])
)
lifted_exported_program = lift_constant_tensor_pass(new_ep)

self.assertEqual(
len(
[
node
for node in lifted_exported_program.graph.nodes
if node.op == "placeholder"
]
),
4,
)
for node in lifted_exported_program.graph.nodes:
self.assertTrue(node.op != "get_attr")

edge_ep = exir.capture(
FooWithBuffer(), (torch.ones(6, 2),), exir.CaptureConfig(enable_aot=True)
).to_edge()
self.assertEqual(
len(
[
node
for node in edge_ep.exported_program.graph.nodes
if node.op == "placeholder"
]
),
4,
)
for node in edge_ep.exported_program.graph.nodes:
self.assertTrue(node.op != "get_attr")

def test_remove_mixed_types_symfloats(self) -> None:
def f(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.interpolate(
Expand Down
6 changes: 3 additions & 3 deletions exir/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,8 @@ def forward(self, x):
return x.cos() + self.buffer.sum()

capture_config = exir.CaptureConfig(enable_aot=True)
captured_gm = exir.capture(
FooWithBuffer(), (torch.ones(6, 2),), capture_config
).exported_program.graph_module
captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
captured_gm = captured_ep.exported_program.graph_module

placeholder_nodes = set()
print(captured_gm.graph)
Expand All @@ -420,6 +419,7 @@ def forward(self, x):
)

self.assertEqual(len(placeholder_nodes), 2)
captured_ep.to_edge()

def test_export_unlift(self) -> None:
class Foo(torch.nn.Module):
Expand Down