Skip to content

Commit c6f1f28

Browse files
mcr229YIWENX14
authored andcommitted
Remove unused Empty Tensors from Edge Graph
Differential Revision: D68589336 Pull Request resolved: #7954
1 parent 501550a commit c6f1f28

File tree

6 files changed

+182
-0
lines changed

6 files changed

+182
-0
lines changed

backends/xnnpack/test/ops/test_cat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self):
187187
inputs.append(torch.randn(1, 2, 3))
188188
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
189189

190+
def test_qs8_cat_with_empty_tensor(self):
191+
inputs = (
192+
torch.randn(0, 2, 3),
193+
torch.randn(1, 2, 3),
194+
torch.randn(3, 2, 3),
195+
torch.randn(0, 2, 3),
196+
)
197+
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
198+
190199
class CatNegativeDim(torch.nn.Module):
191200
def __init__(self):
192201
super().__init__()

exir/passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ python_library(
1717
":memory_planning_pass",
1818
":normalize_transpose_pass",
1919
":prim_ops_py_registry",
20+
":prune_empty_tensor_pass",
2021
":quant_fusion_pass",
2122
":quantize_io_pass",
2223
":remove_noop_pass",
@@ -197,6 +198,18 @@ python_library(
197198
],
198199
)
199200

201+
python_library(
202+
name = "prune_empty_tensor_pass",
203+
srcs = [
204+
"prune_empty_tensors_pass.py",
205+
],
206+
deps = [
207+
"//caffe2:torch",
208+
"//executorch/exir:pass_base",
209+
"//executorch/exir/dialects:lib",
210+
],
211+
)
212+
200213
python_library(
201214
name = "remove_mixed_type_operators",
202215
srcs = [

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
4343
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
45+
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
4546
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4647
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
4748
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
@@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
486487
ScalarToTensorPass(),
487488
SymToTensorPass(),
488489
RemoveNoopPass(),
490+
PruneEmptyTensorsPass(),
489491
RemoveToCopyPass(),
490492
]
491493
).passes
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
from typing import cast, List
9+
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch.fx import GraphModule, Node
14+
15+
# This is a list of ops that are No Ops if used with an empty tensor.
16+
# Which means that if we remove the empty tensor as input to this op,
17+
# The result of the operation will stay the same
18+
19+
20+
class PruneEmptyTensorsPass(ExportPass):
21+
"""
22+
Removes Any empty tensors from the graph that can safely be removed
23+
without affecting the results of the graph. Currently we remove empty
24+
tensor operations from the following ops:
25+
- aten.cat.default
26+
"""
27+
28+
def remove_empty_tensors_from_cat(
29+
self, graph_module: GraphModule, cat_node: Node
30+
) -> None:
31+
"""
32+
Removes empty tensors from the graph that are inputs to aten.cat.default
33+
"""
34+
concat_list = cast(List[Node], cat_node.args[0])
35+
pruned_concat_list = []
36+
for input_arg in concat_list:
37+
input_arg_tensor = input_arg.meta["val"]
38+
if input_arg_tensor.numel() != 0:
39+
pruned_concat_list.append(input_arg)
40+
41+
cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
42+
if len(pruned_concat_list) == 0:
43+
# if all the inputs to the cat are empty tensors, then we can replace
44+
# this concat node with an aten full like
45+
cat_tensor = cat_node.meta["val"]
46+
with graph_module.graph.inserting_after(cat_node):
47+
full_like = graph_module.graph.create_node(
48+
"call_function",
49+
target=exir_ops.edge.aten.full.default,
50+
args=(tuple(cat_tensor.shape), 0),
51+
kwargs={"dtype": cat_tensor.dtype},
52+
)
53+
full_like.meta = cat_node.meta
54+
cat_node.replace_all_uses_with(full_like)
55+
56+
def call(self, graph_module: GraphModule) -> PassResult:
57+
for node in graph_module.graph.nodes:
58+
if node.op != "call_function":
59+
continue
60+
61+
if node.target == torch.ops.aten.cat.default:
62+
self.remove_empty_tensors_from_cat(graph_module, node)
63+
64+
graph_module.graph.eliminate_dead_code()
65+
graph_module.graph.lint()
66+
67+
return PassResult(graph_module, True)

exir/tests/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,20 @@ python_unittest(
443443
],
444444
)
445445

446+
python_unittest(
447+
name = "test_prune_empty_tensors",
448+
srcs = [
449+
"test_prune_empty_tensors_pass.py",
450+
],
451+
deps = [
452+
"//caffe2:torch",
453+
"//executorch/exir:lib",
454+
"//executorch/exir:memory",
455+
"//executorch/exir/capture:config",
456+
"//executorch/exir/passes:lib",
457+
],
458+
)
459+
446460
python_unittest(
447461
name = "warnings",
448462
srcs = [
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
import torch.nn as nn
11+
from executorch.exir import to_edge
12+
from executorch.exir.capture._config import ExecutorchBackendConfig
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.passes import MemoryPlanningPass
15+
16+
17+
class TestCat(nn.Module):
18+
def forward(self, x, y, z):
19+
empty = torch.empty((0, 6))
20+
return torch.cat([empty, x, empty, y, z, empty])
21+
22+
def get_example_inputs(self):
23+
return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))
24+
25+
26+
class TestPruneEmptyTensors(unittest.TestCase):
27+
def test_empty_tensor_removed_from_cat(self) -> None:
28+
model = TestCat()
29+
model.eval()
30+
example_inputs = model.get_example_inputs()
31+
ep = torch.export.export(model, example_inputs, strict=True)
32+
etpm = to_edge(ep).to_executorch(
33+
config=ExecutorchBackendConfig(
34+
remove_view_copy=False,
35+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
36+
),
37+
)
38+
39+
for node in etpm.exported_program().graph_module.graph.nodes:
40+
if node.target in [
41+
exir_ops.edge.aten.cat.default,
42+
torch.ops.aten.cat.default,
43+
]:
44+
self.assertTrue(len(node.all_input_nodes) == 3)
45+
for input_arg in node.all_input_nodes:
46+
tensor_val = input_arg.meta["val"]
47+
self.assertTrue(tensor_val.numel() != 0)
48+
49+
actual = etpm.exported_program().module()(*example_inputs)
50+
51+
reference = model(*example_inputs)
52+
53+
self.assertTrue(torch.allclose(actual, reference))
54+
55+
def test_cat_removed_all_empty(self) -> None:
56+
model = TestCat()
57+
model.eval()
58+
example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
59+
ep = torch.export.export(model, example_inputs, strict=True)
60+
etpm = to_edge(ep).to_executorch(
61+
config=ExecutorchBackendConfig(
62+
remove_view_copy=False,
63+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
64+
),
65+
)
66+
67+
for node in etpm.exported_program().graph_module.graph.nodes:
68+
self.assertFalse(
69+
node.target
70+
in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default]
71+
)
72+
73+
actual = etpm.exported_program().module()(*example_inputs)
74+
75+
reference = model(*example_inputs)
76+
77+
self.assertTrue(torch.allclose(actual, reference))

0 commit comments

Comments
 (0)