Skip to content

Commit 9b3f2ba

Browse files
authored
Add a pass to replace nodes with empty tensors with full.
Differential Revision: D68907459 Pull Request resolved: #8130
1 parent b02c692 commit 9b3f2ba

File tree

4 files changed

+90
-5
lines changed

4 files changed

+90
-5
lines changed

backends/cadence/aot/graph_builder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
from typing import Optional, Sequence, Union
77

88
import torch
9-
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
9+
from executorch.exir.pass_base import (
10+
Argument,
11+
ExportPass,
12+
NodeMetadata,
13+
PassResult,
14+
ProxyValue,
15+
)
1016
from torch._dispatch.python import enable_python_dispatcher
1117
from torch._subclasses import FakeTensor, FakeTensorMode
12-
from torch.fx.node import Argument, Target
18+
from torch.fx.node import Target
1319
from torch.utils import _pytree as pytree
1420

1521

backends/cadence/aot/replace_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,10 +2071,32 @@ def call_operator(
20712071
)
20722072

20732073

2074+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2075+
class ReplaceEmptyTensorsWithFullPass(ExportPass):
2076+
"""Replaces nodes that produce empty tensors with full nodes."""
2077+
2078+
def call_operator(self, op, args, kwargs, meta):
2079+
val = meta.data.get("val", None)
2080+
if isinstance(val, torch.Tensor) and val.numel() == 0:
2081+
return super().call_operator(
2082+
exir_ops.edge.aten.full.default,
2083+
args=(val.shape, 0),
2084+
kwargs={"dtype": val.dtype},
2085+
meta=meta,
2086+
)
2087+
return super().call_operator(op, args, kwargs, meta)
2088+
2089+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2090+
ret = super().call(graph_module)
2091+
modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified
2092+
return PassResult(ret.graph_module, modified)
2093+
2094+
20742095
# This class encapsulates all the functions that replace/switch one op in the
20752096
# graph with another.
20762097
class CadenceReplaceOpsInGraph:
20772098
passes = [
2099+
ReplaceEmptyTensorsWithFullPass,
20782100
ReplaceFunctionallyEquivalentOpTargets,
20792101
ReplaceTCopyWithTransposePass,
20802102
ReplacePermuteWithTransposePass,

backends/cadence/aot/tests/test_graph_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def test_graph_with_single_im2row(self) -> None:
2626
channels_last = False
2727
im2row = builder.call_operator(
2828
exir_ops.edge.cadence.im2row.default,
29-
# pyre-ignore
3029
(
3130
x,
3231
(2, 2),
@@ -80,7 +79,7 @@ def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule:
8079
x = builder.placeholder("x", torch.randn(*x_shape))
8180
add = builder.call_operator(
8281
exir_ops.edge.aten.add.Tensor,
83-
(x, x), # pyre-ignore
82+
(x, x),
8483
)
8584
builder.output([x, add])
8685
gm = builder.get_graph_module()

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import torch.nn.functional as F
88
from executorch.backends.cadence.aot import compiler
99
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
10-
from executorch.backends.cadence.aot.graph_builder import single_op_builder
10+
from executorch.backends.cadence.aot.graph_builder import (
11+
GraphBuilder,
12+
single_op_builder,
13+
)
1114
from executorch.backends.cadence.aot.pass_utils import count_node
1215
from executorch.backends.cadence.aot.replace_ops import (
1316
ForceChannelLastForConvPass,
@@ -18,6 +21,7 @@
1821
ReplaceConstantPadNdWithSlicePass,
1922
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
2023
ReplaceConvWithIm2RowAndLinear,
24+
ReplaceEmptyTensorsWithFullPass,
2125
ReplaceFunctionallyEquivalentOpTargets,
2226
ReplaceIm2RowWithViewPass,
2327
ReplaceLinearWithFullyConnectedOpPass,
@@ -1681,3 +1685,57 @@ def test_cat_insert_transpose(self):
16811685
count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
16821686
3,
16831687
)
1688+
1689+
1690+
class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
1691+
def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
1692+
builder = GraphBuilder()
1693+
x = builder.placeholder("x", torch.randn(4))
1694+
# This is empty (numel == 0).
1695+
slice0 = builder.call_operator(
1696+
exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0)
1697+
)
1698+
# Copy of x.
1699+
slice1 = builder.call_operator(exir_ops.edge.aten.slice_copy.Tensor, (x,))
1700+
cat = builder.call_operator(
1701+
exir_ops.edge.aten.cat.default,
1702+
((slice0, slice1),),
1703+
)
1704+
builder.output([cat])
1705+
return builder.get_graph_module()
1706+
1707+
def test_empty_slice(self):
1708+
gm = self._get_slice_empty_gm()
1709+
self.assertEqual(
1710+
len(
1711+
gm.graph.find_nodes(
1712+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
1713+
)
1714+
),
1715+
2,
1716+
)
1717+
self.assertEqual(
1718+
len(
1719+
gm.graph.find_nodes(
1720+
op="call_function", target=exir_ops.edge.aten.full.default
1721+
)
1722+
),
1723+
0,
1724+
)
1725+
updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module
1726+
self.assertEqual(
1727+
len(
1728+
updated_gm.graph.find_nodes(
1729+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
1730+
)
1731+
),
1732+
1,
1733+
)
1734+
self.assertEqual(
1735+
len(
1736+
updated_gm.graph.find_nodes(
1737+
op="call_function", target=exir_ops.edge.aten.full.default
1738+
)
1739+
),
1740+
1,
1741+
)

0 commit comments

Comments
 (0)