Skip to content

Commit 4cbabce

Browse files
mcr229YIWENX14
authored and committed
Support Empty Input Tensors and > 5 Cat Inputs
Differential Revision: D68523312 Pull Request resolved: #7855
1 parent 8fee8d5 commit 4cbabce

File tree

6 files changed

+257
-51
lines changed

6 files changed

+257
-51
lines changed

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,7 @@ oncall("executorch")
44

55
python_library(
66
name = "xnnpack_passes",
7-
srcs = [
8-
"__init__.py",
9-
"channels_last_tagged_reshape_pass.py",
10-
"conv1d_unsqueeze_pass.py",
11-
"convert_to_linear.py",
12-
"convert_to_sdpa.py",
13-
"convert_to_upsample_bilinear2d.py",
14-
"fuse_activation_pass.py",
15-
"fuse_batch_norm_with_conv.py",
16-
"prelu_reshape_pass.py",
17-
"remove_getitem_op.py",
18-
"tag_implicit_q_dq_pass.py",
19-
"xnnpack_pass.py",
20-
],
7+
srcs = native.glob(["*.py"]),
218
deps = [
229
"//caffe2:torch",
2310
"//executorch/backends/transforms:addmm_mm_to_linear",

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
1818
ConvertToUpsampleBilinear2d,
1919
)
20+
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
2021
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
2122
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
2223
FuseBatchNormWithConvPass,
@@ -63,6 +64,7 @@ def __init__(
6364
ConstPropPass,
6465
FuseBatchNormWithConvPass,
6566
FuseActivationPass,
67+
DecomposeConcatenate,
6668
RemoveGetItemPass,
6769
Conv1dUnsqueezePass,
6870
PReLUReshapePass,
backends/xnnpack/_passes/decompose_cat.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
logger = logging.getLogger(__name__)
16+
logger.setLevel(logging.WARNING)
17+
18+
19+
class DecomposeConcatenate(ExportPass):
    """
    XNNPACK's Concatenate operation only supports concatenation for <= 5 tensors
    at a time. As a result, to support concatenates with > 5 tensors, we can decompose
    concatenates into sequences of cats each with <= 5 tensors.

    Example:
    Before Pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);

    After Pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
        cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);

    When every input of the cat is a dequantize node and every user is a
    quantize node, a quantize/dequantize pair is inserted between the
    truncated cat and the newly created cat so the q/dq pattern is kept.
    """

    # Largest number of tensors a single concatenate may take.
    MAX_CAT_INPUTS = 5

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Split every edge `aten.cat.default` node with more than
        MAX_CAT_INPUTS inputs into a chain of smaller cat nodes.

        Args:
            graph_module: graph module to rewrite; its graph is mutated in place.

        Returns:
            PassResult wrapping the graph module produced by ExportPass.call
            on the rewritten module, always flagged as modified.
        """
        gm = graph_module
        limit = self.MAX_CAT_INPUTS

        # New cat nodes are inserted after the node being visited, so this same
        # loop reaches them later and splits them again if they are still too wide.
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            # getattr guards against call_function targets that carry no __name__.
            if getattr(node.target, "__name__", None) != "aten.cat.default":
                continue

            nodes_to_concat = node.args[0]
            if len(nodes_to_concat) <= limit:
                continue

            # Everything after the tensor list (e.g. the dim) is forwarded unchanged.
            trailing_args = node.args[1:]

            is_quantized = all(
                is_dequant(input_node) for input_node in nodes_to_concat
            ) and all(is_quant(user) for user in node.users)

            # Keep only the first `limit` inputs on the existing cat node.
            remainder_nodes_to_concat = nodes_to_concat[limit:]
            node.args = (nodes_to_concat[:limit],) + trailing_args

            with gm.graph.inserting_after(node):
                logger.debug("Decomposing cat node %s", node)
                remainder_concat_node = gm.graph.create_node(
                    "call_function",
                    target=exir_ops.edge.aten.cat.default,
                    args=([],),  # placeholder; the real inputs are set below
                    kwargs=node.kwargs,
                )
            # Redirect users while the new cat does not yet consume `node`,
            # otherwise the new cat would end up feeding itself.
            node.replace_all_uses_with(remainder_concat_node)

            first_input = node
            if is_quantized:
                # if quantized we need to enforce the q/dq pattern for the newly
                # inserted concat node.
                # Quantizer enforces all the inputs and output to a concat node must
                # share the same qparams, this means the newly inserted q/dq pair
                # must share the same qparams as the first quantized input in the
                # concat node.
                q_params = nodes_to_concat[0].args[1:]
                q_kwargs = nodes_to_concat[0].kwargs
                with gm.graph.inserting_after(node):
                    logger.debug(
                        "Inserting Q/DQ pair for new cat node %s",
                        remainder_concat_node,
                    )
                    q_node = gm.graph.create_node(
                        "call_function",
                        target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                        args=(node,) + q_params,
                        kwargs=q_kwargs,
                    )
                with gm.graph.inserting_after(q_node):
                    dq_node = gm.graph.create_node(
                        "call_function",
                        target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                        args=(q_node,) + q_params,
                        kwargs=q_kwargs,
                    )
                first_input = dq_node

            remainder_concat_node.args = (
                [first_input] + remainder_nodes_to_concat,
            ) + trailing_args

        gm.recompile()
        new_gm = super().call(gm).graph_module
        return PassResult(new_gm, True)

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
181181

182182
num_tensors = len(node.all_input_nodes)
183183

184-
if not (num_tensors >= 2 and num_tensors <= 5):
184+
if not (num_tensors >= 2):
185185
why(
186186
node,
187-
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
187+
reason=f"only support concatenation of > 2 tensors, got {num_tensors} tensors",
188188
)
189189
return False
190190

backends/xnnpack/test/ops/test_cat.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414

1515
class TestCat(unittest.TestCase):
1616
class Cat(torch.nn.Module):
17+
def __init__(self, dim=0):
18+
super().__init__()
19+
self.dim = dim
20+
1721
def forward(self, *args):
1822
xs = [*args]
19-
x = torch.cat(xs)
23+
x = torch.cat(xs, dim=self.dim)
2024
return x + x # Quantize by propagation.
2125

2226
def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
2731
tester.quantize()
2832

2933
tester.export().check_count({"torch.ops.aten.cat": 1})
30-
tester.dump_artifact()
3134

3235
if quant:
3336
# Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
9396
)
9497
self._test_cat(self.Cat(), inputs)
9598

99+
def test_fp16_cat5(self):
100+
"""
101+
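Concatenates five fp16 tensors through self.Cat(). NOTE(review): the following remark mentions Clamp2, which is not used in this test; confirm it still applies. Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.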
102+
"""
103+
inputs = (
104+
torch.randn(1, 2, 3).to(torch.float16),
105+
torch.randn(3, 2, 3).to(torch.float16),
106+
torch.randn(2, 2, 3).to(torch.float16),
107+
torch.randn(5, 2, 3).to(torch.float16),
108+
torch.randn(5, 2, 3).to(torch.float16),
109+
)
110+
self._test_cat(self.Cat(), inputs)
111+
112+
def test_fp16_cat_gt_5(self):
113+
"""
114+
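Concatenates 6 to 9 fp16 tensors through self.Cat(). NOTE(review): the following remark mentions Clamp2, which is not used in this test; confirm it still applies. Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.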
115+
"""
116+
for num_inputs in range(6, 10):
117+
inputs = []
118+
for _ in range(num_inputs):
119+
inputs.append(torch.randn(1, 2, 3).to(torch.float16))
120+
self._test_cat(self.Cat(), tuple(inputs))
121+
96122
def test_fp32_cat2(self):
97123
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
98124
self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
120146
)
121147
self._test_cat(self.Cat(), inputs)
122148

149+
def test_fp32_cat_gt_5(self):
150+
for num_inputs in range(6, 10):
151+
inputs = []
152+
for _ in range(num_inputs):
153+
inputs.append(torch.randn(1, 2, 3))
154+
self._test_cat(self.Cat(), tuple(inputs))
155+
123156
def test_qs8_cat2(self):
124157
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
125158
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
137170
)
138171
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
139172

140-
def test_fp32_cat_unsupported(self):
141-
"""
142-
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
143-
"""
173+
def test_qs8_cat5(self):
144174
inputs = (
145175
torch.randn(1, 2, 3),
146176
torch.randn(3, 2, 3),
147177
torch.randn(2, 2, 3),
148178
torch.randn(5, 2, 3),
149-
torch.randn(1, 2, 3),
150-
torch.randn(2, 2, 3),
151-
)
152-
(
153-
Tester(self.Cat(), inputs)
154-
.export()
155-
.check_count({"torch.ops.aten.cat": 1})
156-
.to_edge_transform_and_lower()
157-
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
158-
)
159-
160-
def test_fp32_cat_unsupported_legacy_mode(self):
161-
"""
162-
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
163-
"""
164-
inputs = (
165-
torch.randn(1, 2, 3),
166-
torch.randn(3, 2, 3),
167-
torch.randn(2, 2, 3),
168179
torch.randn(5, 2, 3),
169-
torch.randn(1, 2, 3),
170-
torch.randn(6, 2, 3),
171-
)
172-
(
173-
Tester(self.Cat(), inputs)
174-
.export()
175-
.check_count({"torch.ops.aten.cat": 1})
176-
.to_edge()
177-
.partition()
178-
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
179180
)
181+
self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)
182+
183+
def test_qs8_cat_gt_5(self):
184+
for num_inputs in range(6, 10):
185+
inputs = []
186+
for _ in range(num_inputs):
187+
inputs.append(torch.randn(1, 2, 3))
188+
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
180189

181190
class CatNegativeDim(torch.nn.Module):
182191
def __init__(self):
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
import unittest
9+
10+
import torch
11+
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
12+
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
13+
14+
15+
class TestDecomposeCatPass(unittest.TestCase):
    """
    Checks that DecomposeConcatenate rewrites a single wide cat into the
    expected number of cat nodes, for both fp32 and quantized graphs, and
    that the rewritten graph still matches the reference outputs.
    """

    PassStage = RunPasses([DecomposeConcatenate])
    cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"

    class Cat(torch.nn.Module):
        def forward(self, *args):
            xs = [*args]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    @staticmethod
    def _expected_num_cats(num_inputs):
        """
        Return how many cat nodes are expected after the pass.

        One cat for the first 5 inputs, then one more cat per group of up to
        4 remaining inputs (each follow-up cat also takes the previous result).
        """
        num_cats = int(num_inputs > 5)
        num_cats += math.ceil((num_inputs - 5) / 4)
        return num_cats

    def _check_decomposition(self, input_counts, quantize=False):
        """
        Run the pass on a cat of `num_inputs` tensors for each count given.

        Args:
            input_counts: iterable of input-tensor counts to exercise.
            quantize: when True, run the tester's quantize stage before export.
        """
        for num_inputs in input_counts:
            with self.subTest(num_inputs=num_inputs, quantize=quantize):
                inputs = tuple(torch.randn(1, 2, 3) for _ in range(num_inputs))
                num_cats = self._expected_num_cats(num_inputs)

                tester = Tester(self.Cat(), inputs)
                if quantize:
                    tester.quantize()
                (
                    tester.export()
                    .to_edge()
                    .check_count({self.cat_name: 1})
                    .run_passes(self.PassStage)
                    .check_count({self.cat_name: num_cats})
                    .run_method_and_compare_outputs()
                )

    def test_cat_gt_5(self):
        self._check_decomposition(range(6, 10))

    def test_cat_gt_10(self):
        self._check_decomposition([11, 16, 18])

    def test_qs8_cat_gt_5(self):
        self._check_decomposition(range(6, 10), quantize=True)

    def test_qs8_cat_gt_10(self):
        self._check_decomposition([11, 16, 18], quantize=True)

0 commit comments

Comments
 (0)