Skip to content

Commit ed15042

Browse files
authored
Support 5-input concat in XNNPACK delegate
Differential Revision: D67439458
Pull Request resolved: #7401
1 parent 7af5f6d commit ed15042

File tree

7 files changed: +98 -38 lines changed

backends/xnnpack/operators/op_cat.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
from typing import cast, Dict, List
810

911
import torch
@@ -17,6 +19,7 @@
1719
XNNConcatenate2,
1820
XNNConcatenate3,
1921
XNNConcatenate4,
22+
XNNConcatenate5,
2023
XNNGraph,
2124
XNode,
2225
)
@@ -71,6 +74,7 @@ def define_node(
7174
input2_id=vals_to_ids[list_of_tensors[1]],
7275
input3_id=XNN_INVALID_VALUE_ID,
7376
input4_id=XNN_INVALID_VALUE_ID,
77+
input5_id=XNN_INVALID_VALUE_ID,
7478
output_id=vals_to_ids[node],
7579
flags=0,
7680
)
@@ -81,6 +85,7 @@ def define_node(
8185
input2_id=vals_to_ids[list_of_tensors[1]],
8286
input3_id=vals_to_ids[list_of_tensors[2]],
8387
input4_id=XNN_INVALID_VALUE_ID,
88+
input5_id=XNN_INVALID_VALUE_ID,
8489
output_id=vals_to_ids[node],
8590
flags=0,
8691
)
@@ -91,6 +96,18 @@ def define_node(
9196
input2_id=vals_to_ids[list_of_tensors[1]],
9297
input3_id=vals_to_ids[list_of_tensors[2]],
9398
input4_id=vals_to_ids[list_of_tensors[3]],
99+
input5_id=XNN_INVALID_VALUE_ID,
100+
output_id=vals_to_ids[node],
101+
flags=0,
102+
)
103+
elif num_tensors_to_cat == 5:
104+
xnode = XNNConcatenate5(
105+
axis=axis,
106+
input1_id=vals_to_ids[list_of_tensors[0]],
107+
input2_id=vals_to_ids[list_of_tensors[1]],
108+
input3_id=vals_to_ids[list_of_tensors[2]],
109+
input4_id=vals_to_ids[list_of_tensors[3]],
110+
input5_id=vals_to_ids[list_of_tensors[4]],
94111
output_id=vals_to_ids[node],
95112
flags=0,
96113
)

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -174,17 +174,17 @@ class CatConfig(GenericNodePartitionerConfig):
174174

175175
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
176176
"""
177-
Only support concatenation of 2 - 4 tensors
177+
Only support concatenation of 2 - 5 tensors
178178
"""
179179
if not self.check_common_constraints(node, ep):
180180
return False
181181

182182
num_tensors = len(node.all_input_nodes)
183183

184-
if not (num_tensors >= 2 and num_tensors <= 4):
184+
if not (num_tensors >= 2 and num_tensors <= 5):
185185
why(
186186
node,
187-
reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors",
187+
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
188188
)
189189
return False
190190

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 38 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1600,7 +1600,7 @@ Error defineConcatenate2Node(
16001600
}
16011601

16021602
/*
1603-
Defines serialized concatenate2 node into the subgraph,
1603+
Defines serialized concatenate3 node into the subgraph,
16041604
using the remapped ids to map the serialized ids,
16051605
to the new ids generated when defining the tensor value
16061606
*/
@@ -1633,7 +1633,7 @@ Error defineConcatenate3Node(
16331633
}
16341634

16351635
/*
1636-
Defines serialized concatenate2 node into the subgraph,
1636+
Defines serialized concatenate4 node into the subgraph,
16371637
using the remapped ids to map the serialized ids,
16381638
to the new ids generated when defining the tensor value
16391639
*/
@@ -1666,6 +1666,41 @@ Error defineConcatenate4Node(
16661666
return Error::Ok;
16671667
}
16681668

1669+
/*
1670+
Defines serialized concatenate5 node into the subgraph,
1671+
using the remapped ids to map the serialized ids,
1672+
to the new ids generated when defining the tensor value
1673+
*/
1674+
Error defineConcatenate5Node(
1675+
xnn_subgraph_t subgraph_ptr,
1676+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1677+
const NodePtr node,
1678+
const fb_xnnpack::XNNGraph* graph) noexcept {
1679+
MAYBE_UNUSED(graph);
1680+
1681+
auto graph_node = node->xnode_union_as_XNNConcatenate5();
1682+
1683+
xnn_status status = xnn_define_concatenate5(
1684+
subgraph_ptr,
1685+
graph_node->axis(),
1686+
remapped_ids.at(graph_node->input1_id()),
1687+
remapped_ids.at(graph_node->input2_id()),
1688+
remapped_ids.at(graph_node->input3_id()),
1689+
remapped_ids.at(graph_node->input4_id()),
1690+
remapped_ids.at(graph_node->input5_id()),
1691+
remapped_ids.at(graph_node->output_id()),
1692+
graph_node->flags());
1693+
1694+
ET_CHECK_OR_RETURN_ERROR(
1695+
status == xnn_status_success,
1696+
Internal,
1697+
"Failed to create cat5 node %i with code: %s",
1698+
node->debug_handle(),
1699+
xnn_status_to_string(status));
1700+
1701+
return Error::Ok;
1702+
}
1703+
16691704
/*
16701705
Defines serialized static_slice node into the subgraph,
16711706
using the remapped ids to map the serialized ids,
@@ -1832,6 +1867,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
18321867
_DEFINE(Concatenate2)
18331868
_DEFINE(Concatenate3)
18341869
_DEFINE(Concatenate4)
1870+
_DEFINE(Concatenate5)
18351871
_DEFINE(StaticSlice)
18361872
_DEFINE(ScaledDotProductAttention)
18371873
_DEFINE(BatchMatrixMultiply)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@ union XNodeUnion {
136136
XNNStaticSlice,
137137
XNNScaledDotProductAttention,
138138
XNNBatchMatrixMultiply: _XNNNode2x1,
139+
XNNConcatenate5: _XNNCat,
139140
}
140141

141142
union XValueUnion {
@@ -209,6 +210,7 @@ table _XNNCat {
209210
input4_id: uint;
210211
output_id: uint;
211212
flags: uint;
213+
input5_id: uint;
212214
}
213215

214216
table XNNELU {

backends/xnnpack/serialization/schema.fbs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ union XNodeUnion {
132132
XNNStaticSlice,
133133
XNNScaledDotProductAttention,
134134
XNNBatchMatrixMultiply: _XNNNode2x1,
135+
XNNConcatenate5: _XNNCat,
135136
}
136137

137138
union XValueUnion {
@@ -205,6 +206,7 @@ table _XNNCat {
205206
input4_id: uint;
206207
output_id: uint;
207208
flags: uint;
209+
input5_id: uint;
208210
}
209211

210212
table XNNELU {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ class XNNCat:
4242
input4_id: int
4343
output_id: int
4444
flags: int
45+
input5_id: int
4546

4647

4748
# Generic node data class for convolution type nodes
@@ -177,6 +178,11 @@ class XNNConcatenate4(XNNCat):
177178
pass
178179

179180

181+
@dataclass
182+
class XNNConcatenate5(XNNCat):
183+
pass
184+
185+
180186
@dataclass
181187
class XNNBatchMatrixMultiply(XNNNode2x1):
182188
pass
@@ -357,6 +363,7 @@ class XNNScaledDotProductAttention:
357363
XNNConcatenate2,
358364
XNNConcatenate3,
359365
XNNConcatenate4,
366+
XNNConcatenate5,
360367
XNNStaticSlice,
361368
XNNScaledDotProductAttention,
362369
XNNBatchMatrixMultiply,

backends/xnnpack/test/ops/test_cat.py

Lines changed: 29 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -4,34 +4,18 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import unittest
810

911
import torch
1012
from executorch.backends.xnnpack.test.tester import Tester
1113

1214

1315
class TestCat(unittest.TestCase):
14-
class Cat2(torch.nn.Module):
15-
def forward(self, arg1, arg2):
16-
xs = [arg1, arg2]
17-
x = torch.cat(xs)
18-
return x + x # Quantize by propagation.
19-
20-
class Cat3(torch.nn.Module):
21-
def forward(self, arg1, arg2, arg3):
22-
xs = [arg1, arg2, arg3]
23-
x = torch.cat(xs)
24-
return x + x # Quantize by propagation.
25-
26-
class Cat4(torch.nn.Module):
27-
def forward(self, arg1, arg2, arg3, arg4):
28-
xs = [arg1, arg2, arg3, arg4]
29-
x = torch.cat(xs)
30-
return x + x # Quantize by propagation.
31-
32-
class Cat5(torch.nn.Module):
33-
def forward(self, arg1, arg2, arg3, arg4, arg5):
34-
xs = [arg1, arg2, arg3, arg4, arg5]
16+
class Cat(torch.nn.Module):
17+
def forward(self, *args):
18+
xs = [*args]
3519
x = torch.cat(xs)
3620
return x + x # Quantize by propagation.
3721

@@ -84,7 +68,7 @@ def test_fp16_cat2(self):
8468
torch.randn(1, 2, 3).to(torch.float16),
8569
torch.randn(3, 2, 3).to(torch.float16),
8670
)
87-
self._test_cat(self.Cat2(), inputs)
71+
self._test_cat(self.Cat(), inputs)
8872

8973
def test_fp16_cat3(self):
9074
"""
@@ -95,7 +79,7 @@ def test_fp16_cat3(self):
9579
torch.randn(3, 2, 3).to(torch.float16),
9680
torch.randn(2, 2, 3).to(torch.float16),
9781
)
98-
self._test_cat(self.Cat3(), inputs)
82+
self._test_cat(self.Cat(), inputs)
9983

10084
def test_fp16_cat4(self):
10185
"""
@@ -107,15 +91,15 @@ def test_fp16_cat4(self):
10791
torch.randn(2, 2, 3).to(torch.float16),
10892
torch.randn(5, 2, 3).to(torch.float16),
10993
)
110-
self._test_cat(self.Cat4(), inputs)
94+
self._test_cat(self.Cat(), inputs)
11195

11296
def test_fp32_cat2(self):
11397
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
114-
self._test_cat(self.Cat2(), inputs)
98+
self._test_cat(self.Cat(), inputs)
11599

116100
def test_fp32_cat3(self):
117101
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
118-
self._test_cat(self.Cat3(), inputs)
102+
self._test_cat(self.Cat(), inputs)
119103

120104
def test_fp32_cat4(self):
121105
inputs = (
@@ -124,15 +108,25 @@ def test_fp32_cat4(self):
124108
torch.randn(2, 2, 3),
125109
torch.randn(5, 2, 3),
126110
)
127-
self._test_cat(self.Cat4(), inputs)
111+
self._test_cat(self.Cat(), inputs)
112+
113+
def test_fp32_cat5(self):
114+
inputs = (
115+
torch.randn(1, 2, 3),
116+
torch.randn(3, 2, 3),
117+
torch.randn(2, 2, 3),
118+
torch.randn(5, 2, 3),
119+
torch.randn(1, 2, 3),
120+
)
121+
self._test_cat(self.Cat(), inputs)
128122

129123
def test_qs8_cat2(self):
130124
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
131-
self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)
125+
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
132126

133127
def test_qs8_cat3(self):
134128
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
135-
self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)
129+
self._test_cat(self.Cat(), inputs, cat_num=3, quant=True)
136130

137131
def test_qs8_cat4(self):
138132
inputs = (
@@ -141,7 +135,7 @@ def test_qs8_cat4(self):
141135
torch.randn(2, 2, 3),
142136
torch.randn(5, 2, 3),
143137
)
144-
self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)
138+
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
145139

146140
def test_fp32_cat_unsupported(self):
147141
"""
@@ -153,9 +147,10 @@ def test_fp32_cat_unsupported(self):
153147
torch.randn(2, 2, 3),
154148
torch.randn(5, 2, 3),
155149
torch.randn(1, 2, 3),
150+
torch.randn(2, 2, 3),
156151
)
157152
(
158-
Tester(self.Cat5(), inputs)
153+
Tester(self.Cat(), inputs)
159154
.export()
160155
.check_count({"torch.ops.aten.cat": 1})
161156
.to_edge_transform_and_lower()
@@ -164,17 +159,18 @@ def test_fp32_cat_unsupported(self):
164159

165160
def test_fp32_cat_unsupported_legacy_mode(self):
166161
"""
167-
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
162+
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
168163
"""
169164
inputs = (
170165
torch.randn(1, 2, 3),
171166
torch.randn(3, 2, 3),
172167
torch.randn(2, 2, 3),
173168
torch.randn(5, 2, 3),
174169
torch.randn(1, 2, 3),
170+
torch.randn(6, 2, 3),
175171
)
176172
(
177-
Tester(self.Cat5(), inputs)
173+
Tester(self.Cat(), inputs)
178174
.export()
179175
.check_count({"torch.ops.aten.cat": 1})
180176
.to_edge()

0 commit comments

Comments
 (0)