Skip to content

Commit 713b5e1

Browse files
GregoryComer authored and facebook-github-bot committed
Support 5-input concat in XNNPACK delegate
Summary: I noticed that support for 5-input concatenate ops was added to the XNNPACK library subgraph layer. We can support this in the delegate. Differential Revision: D67439458
1 parent 7d07409 commit 713b5e1

File tree

7 files changed

+93
-38
lines changed

7 files changed

+93
-38
lines changed

backends/xnnpack/operators/op_cat.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
XNNConcatenate2,
1818
XNNConcatenate3,
1919
XNNConcatenate4,
20+
XNNConcatenate5,
2021
XNNGraph,
2122
XNode,
2223
)
@@ -71,6 +72,7 @@ def define_node(
7172
input2_id=vals_to_ids[list_of_tensors[1]],
7273
input3_id=XNN_INVALID_VALUE_ID,
7374
input4_id=XNN_INVALID_VALUE_ID,
75+
input5_id=XNN_INVALID_VALUE_ID,
7476
output_id=vals_to_ids[node],
7577
flags=0,
7678
)
@@ -81,6 +83,7 @@ def define_node(
8183
input2_id=vals_to_ids[list_of_tensors[1]],
8284
input3_id=vals_to_ids[list_of_tensors[2]],
8385
input4_id=XNN_INVALID_VALUE_ID,
86+
input5_id=XNN_INVALID_VALUE_ID,
8487
output_id=vals_to_ids[node],
8588
flags=0,
8689
)
@@ -91,6 +94,18 @@ def define_node(
9194
input2_id=vals_to_ids[list_of_tensors[1]],
9295
input3_id=vals_to_ids[list_of_tensors[2]],
9396
input4_id=vals_to_ids[list_of_tensors[3]],
97+
input5_id=XNN_INVALID_VALUE_ID,
98+
output_id=vals_to_ids[node],
99+
flags=0,
100+
)
101+
elif num_tensors_to_cat == 5:
102+
xnode = XNNConcatenate5(
103+
axis=axis,
104+
input1_id=vals_to_ids[list_of_tensors[0]],
105+
input2_id=vals_to_ids[list_of_tensors[1]],
106+
input3_id=vals_to_ids[list_of_tensors[2]],
107+
input4_id=vals_to_ids[list_of_tensors[3]],
108+
input5_id=vals_to_ids[list_of_tensors[4]],
94109
output_id=vals_to_ids[node],
95110
flags=0,
96111
)

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,17 @@ class CatConfig(GenericNodePartitionerConfig):
172172

173173
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
174174
"""
175-
Only support concatenation of 2 - 4 tensors
175+
Only support concatenation of 2 - 5 tensors
176176
"""
177177
if not self.check_common_constraints(node, ep):
178178
return False
179179

180180
num_tensors = len(node.all_input_nodes)
181181

182-
if not (num_tensors >= 2 and num_tensors <= 4):
182+
if not (num_tensors >= 2 and num_tensors <= 5):
183183
why(
184184
node,
185-
reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors",
185+
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
186186
)
187187
return False
188188

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,7 @@ Error defineConcatenate2Node(
16001600
}
16011601

16021602
/*
1603-
Defines serialized concatenate2 node into the subgraph,
1603+
Defines serialized concatenate3 node into the subgraph,
16041604
using the remapped ids to map the serialized ids,
16051605
to the new ids generated when defining the tensor value
16061606
*/
@@ -1633,7 +1633,7 @@ Error defineConcatenate3Node(
16331633
}
16341634

16351635
/*
1636-
Defines serialized concatenate2 node into the subgraph,
1636+
Defines serialized concatenate4 node into the subgraph,
16371637
using the remapped ids to map the serialized ids,
16381638
to the new ids generated when defining the tensor value
16391639
*/
@@ -1666,6 +1666,41 @@ Error defineConcatenate4Node(
16661666
return Error::Ok;
16671667
}
16681668

1669+
/*
1670+
Defines serialized concatenate5 node into the subgraph,
1671+
using the remapped ids to map the serialized ids,
1672+
to the new ids generated when defining the tensor value
1673+
*/
1674+
Error defineConcatenate5Node(
1675+
xnn_subgraph_t subgraph_ptr,
1676+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1677+
const NodePtr node,
1678+
const fb_xnnpack::XNNGraph* graph) noexcept {
1679+
MAYBE_UNUSED(graph);
1680+
1681+
auto graph_node = node->xnode_union_as_XNNConcatenate5();
1682+
1683+
xnn_status status = xnn_define_concatenate5(
1684+
subgraph_ptr,
1685+
graph_node->axis(),
1686+
remapped_ids.at(graph_node->input1_id()),
1687+
remapped_ids.at(graph_node->input2_id()),
1688+
remapped_ids.at(graph_node->input3_id()),
1689+
remapped_ids.at(graph_node->input4_id()),
1690+
remapped_ids.at(graph_node->input5_id()),
1691+
remapped_ids.at(graph_node->output_id()),
1692+
graph_node->flags());
1693+
1694+
ET_CHECK_OR_RETURN_ERROR(
1695+
status == xnn_status_success,
1696+
Internal,
1697+
"Failed to create cat5 node %i with code: %s",
1698+
node->debug_handle(),
1699+
xnn_status_to_string(status));
1700+
1701+
return Error::Ok;
1702+
}
1703+
16691704
/*
16701705
Defines serialized static_slice node into the subgraph,
16711706
using the remapped ids to map the serialized ids,
@@ -1832,6 +1867,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
18321867
_DEFINE(Concatenate2)
18331868
_DEFINE(Concatenate3)
18341869
_DEFINE(Concatenate4)
1870+
_DEFINE(Concatenate5)
18351871
_DEFINE(StaticSlice)
18361872
_DEFINE(ScaledDotProductAttention)
18371873
_DEFINE(BatchMatrixMultiply)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ union XNodeUnion {
133133
XNNConcatenate2: _XNNCat,
134134
XNNConcatenate3: _XNNCat,
135135
XNNConcatenate4: _XNNCat,
136+
XNNConcatenate5: _XNNCat,
136137
XNNStaticSlice,
137138
XNNScaledDotProductAttention,
138139
XNNBatchMatrixMultiply: _XNNNode2x1,
@@ -209,6 +210,7 @@ table _XNNCat {
209210
input4_id: uint;
210211
output_id: uint;
211212
flags: uint;
213+
input5_id: uint;
212214
}
213215

214216
table XNNELU {

backends/xnnpack/serialization/schema.fbs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ union XNodeUnion {
129129
XNNConcatenate2: _XNNCat,
130130
XNNConcatenate3: _XNNCat,
131131
XNNConcatenate4: _XNNCat,
132+
XNNConcatenate5: _XNNCat,
132133
XNNStaticSlice,
133134
XNNScaledDotProductAttention,
134135
XNNBatchMatrixMultiply: _XNNNode2x1,
@@ -205,6 +206,7 @@ table _XNNCat {
205206
input4_id: uint;
206207
output_id: uint;
207208
flags: uint;
209+
input5_id: uint;
208210
}
209211

210212
table XNNELU {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class XNNCat:
4242
input4_id: int
4343
output_id: int
4444
flags: int
45+
input5_id: int
4546

4647

4748
# Generic node data class for convolution type nodes
@@ -177,6 +178,11 @@ class XNNConcatenate4(XNNCat):
177178
pass
178179

179180

181+
@dataclass
182+
class XNNConcatenate5(XNNCat):
183+
pass
184+
185+
180186
@dataclass
181187
class XNNBatchMatrixMultiply(XNNNode2x1):
182188
pass

backends/xnnpack/test/ops/test_cat.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,9 @@
1111

1212

1313
class TestCat(unittest.TestCase):
14-
class Cat2(torch.nn.Module):
15-
def forward(self, arg1, arg2):
16-
xs = [arg1, arg2]
17-
x = torch.cat(xs)
18-
return x + x # Quantize by propagation.
19-
20-
class Cat3(torch.nn.Module):
21-
def forward(self, arg1, arg2, arg3):
22-
xs = [arg1, arg2, arg3]
23-
x = torch.cat(xs)
24-
return x + x # Quantize by propagation.
25-
26-
class Cat4(torch.nn.Module):
27-
def forward(self, arg1, arg2, arg3, arg4):
28-
xs = [arg1, arg2, arg3, arg4]
29-
x = torch.cat(xs)
30-
return x + x # Quantize by propagation.
31-
32-
class Cat5(torch.nn.Module):
33-
def forward(self, arg1, arg2, arg3, arg4, arg5):
34-
xs = [arg1, arg2, arg3, arg4, arg5]
14+
class Cat(torch.nn.Module):
15+
def forward(self, *args):
16+
xs = [*args]
3517
x = torch.cat(xs)
3618
return x + x # Quantize by propagation.
3719

@@ -84,7 +66,7 @@ def test_fp16_cat2(self):
8466
torch.randn(1, 2, 3).to(torch.float16),
8567
torch.randn(3, 2, 3).to(torch.float16),
8668
)
87-
self._test_cat(self.Cat2(), inputs)
69+
self._test_cat(self.Cat(), inputs)
8870

8971
def test_fp16_cat3(self):
9072
"""
@@ -95,7 +77,7 @@ def test_fp16_cat3(self):
9577
torch.randn(3, 2, 3).to(torch.float16),
9678
torch.randn(2, 2, 3).to(torch.float16),
9779
)
98-
self._test_cat(self.Cat3(), inputs)
80+
self._test_cat(self.Cat(), inputs)
9981

10082
def test_fp16_cat4(self):
10183
"""
@@ -107,15 +89,15 @@ def test_fp16_cat4(self):
10789
torch.randn(2, 2, 3).to(torch.float16),
10890
torch.randn(5, 2, 3).to(torch.float16),
10991
)
110-
self._test_cat(self.Cat4(), inputs)
92+
self._test_cat(self.Cat(), inputs)
11193

11294
def test_fp32_cat2(self):
11395
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
114-
self._test_cat(self.Cat2(), inputs)
96+
self._test_cat(self.Cat(), inputs)
11597

11698
def test_fp32_cat3(self):
11799
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
118-
self._test_cat(self.Cat3(), inputs)
100+
self._test_cat(self.Cat(), inputs)
119101

120102
def test_fp32_cat4(self):
121103
inputs = (
@@ -124,15 +106,25 @@ def test_fp32_cat4(self):
124106
torch.randn(2, 2, 3),
125107
torch.randn(5, 2, 3),
126108
)
127-
self._test_cat(self.Cat4(), inputs)
109+
self._test_cat(self.Cat(), inputs)
110+
111+
def test_fp32_cat5(self):
112+
inputs = (
113+
torch.randn(1, 2, 3),
114+
torch.randn(3, 2, 3),
115+
torch.randn(2, 2, 3),
116+
torch.randn(5, 2, 3),
117+
torch.randn(1, 2, 3),
118+
)
119+
self._test_cat(self.Cat(), inputs)
128120

129121
def test_qs8_cat2(self):
130122
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
131-
self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)
123+
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
132124

133125
def test_qs8_cat3(self):
134126
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
135-
self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)
127+
self._test_cat(self.Cat(), inputs, cat_num=3, quant=True)
136128

137129
def test_qs8_cat4(self):
138130
inputs = (
@@ -141,7 +133,7 @@ def test_qs8_cat4(self):
141133
torch.randn(2, 2, 3),
142134
torch.randn(5, 2, 3),
143135
)
144-
self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)
136+
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
145137

146138
def test_fp32_cat_unsupported(self):
147139
"""
@@ -153,9 +145,10 @@ def test_fp32_cat_unsupported(self):
153145
torch.randn(2, 2, 3),
154146
torch.randn(5, 2, 3),
155147
torch.randn(1, 2, 3),
148+
torch.randn(2, 2, 3),
156149
)
157150
(
158-
Tester(self.Cat5(), inputs)
151+
Tester(self.Cat(), inputs)
159152
.export()
160153
.check_count({"torch.ops.aten.cat": 1})
161154
.to_edge_transform_and_lower()
@@ -164,17 +157,18 @@ def test_fp32_cat_unsupported(self):
164157

165158
def test_fp32_cat_unsupported_legacy_mode(self):
166159
"""
167-
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
160+
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
168161
"""
169162
inputs = (
170163
torch.randn(1, 2, 3),
171164
torch.randn(3, 2, 3),
172165
torch.randn(2, 2, 3),
173166
torch.randn(5, 2, 3),
174167
torch.randn(1, 2, 3),
168+
torch.randn(6, 2, 3),
175169
)
176170
(
177-
Tester(self.Cat5(), inputs)
171+
Tester(self.Cat(), inputs)
178172
.export()
179173
.check_count({"torch.ops.aten.cat": 1})
180174
.to_edge()

0 commit comments

Comments (0)