1
+ import warnings
1
2
from typing import Dict
2
3
3
4
import executorch .backends .qualcomm .python .PyQnnWrapperAdaptor as PyQnnWrapper
5
+ import numpy as np
4
6
import torch
5
7
6
- from executorch .backends .qualcomm .utils .constants import QCOM_QUANT_ATTRS
8
+ from executorch .backends .qualcomm .utils .constants import QCOM_DATA , QCOM_QUANT_ATTRS
9
+ from executorch .exir .dialects ._ops import ops as exir_ops
7
10
8
- from .node_visitor import NodeVisitor
11
+ from .node_visitor import NodeVisitor , QNN_TENSOR_TYPE_MAP
9
12
from .node_visitor_manager import register_node_visitor
10
- from .qnn_constants import OpScatterNd , QNN_OP_PACKAGE_NAME_QTI_AISW
13
+ from .qnn_constants import (
14
+ OpConcat ,
15
+ OpReshape ,
16
+ OpScatterNd ,
17
+ OpTile ,
18
+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
19
+ )
11
20
12
21
13
22
@register_node_visitor
@@ -22,6 +31,7 @@ def define_node(
22
31
node : torch .fx .Node ,
23
32
nodes_to_wrappers : Dict [torch .fx .Node , PyQnnWrapper .TensorWrapper ],
24
33
) -> PyQnnWrapper .PyQnnOpWrapper :
34
+ op_wrapper_list = []
25
35
input_node = self .get_node (node .args [0 ])
26
36
# Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here.
27
37
if quant_attrs := node .meta .get (QCOM_QUANT_ATTRS ):
@@ -35,38 +45,206 @@ def define_node(
35
45
PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
36
46
nodes_to_wrappers ,
37
47
)
38
- indicies_node = node .args [1 ]
39
- indices_list = [
40
- self .get_tensor (idx , idx ) for idx in indicies_node if idx is not None
41
- ]
42
-
43
- # Unpack the tuple
44
- indices_unpacked = [torch .flatten (idx ) for idx in indices_list ]
45
-
46
- # Convert to 2-D tensor
47
- indices_qnn = torch .cat (indices_unpacked ).unsqueeze (0 )
48
- indice_node = [n for n in indicies_node if isinstance (n , torch .fx .Node )]
49
- # TODO consider to write a pass to combine to one input tensor for indices
50
- assert len (indice_node ) == 1 , "Not support multiple indices tensor"
51
48
49
+ indicies_node = node .args [1 ]
50
+ index_node_dim = None
51
+ index_nodes = []
52
+ index_tensors = []
53
+ target_index = []
54
+ # If there is None in a list, it means all range at that dimension
55
+ # E.g., indicies_node: [None, None, aten__to_copy_default_1]
56
+ if isinstance (indicies_node , list ):
57
+ for index , idx_node in enumerate (indicies_node ):
58
+ # First, collect the indice_node and index of None to construct the shape of index node
59
+ # E.g., shape of input: [1, 1024, 12, 64]
60
+ # For "None" axis (assume indicies_node: [None, None, aten__to_copy_default_1]),
61
+ # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2
62
+ if isinstance (idx_node , torch .fx .Node ):
63
+ index_nodes .append (idx_node )
64
+ index_tensors .append (self .get_tensor (idx_node , idx_node ))
65
+ target_index .extend (index_tensors [- 1 ].size ())
66
+ index_node_dim = index
67
+ elif idx_node is None and index_node_dim is None :
68
+ # E.g., indicies_node: [None, aten__to_copy_default_1, None]
69
+ # Don't need to consider "None" after index_node.
70
+ target_index .append (input_tensor .size (index ))
71
+ else :
72
+ warnings .warn (
73
+ f"[QNN Delegate Op Builder]: Get the index { idx_node } that is neither a node nor None" ,
74
+ stacklevel = 1 ,
75
+ )
76
+ return
77
+ # Assume that there is only one node in list
78
+ assert len (index_nodes ) == 1 , "Not support multiple indices tensor"
79
+ indice_node = index_nodes [0 ]
80
+ indice_tensor = index_tensors [0 ]
52
81
indices_tensor_wrapper = self .define_tensor (
53
- indice_node [ 0 ] ,
82
+ indice_node ,
54
83
node ,
55
- indices_qnn ,
84
+ indice_tensor ,
56
85
PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
57
86
nodes_to_wrappers ,
58
87
)
59
- value_node = self .get_node (node .args [2 ])
60
88
61
- value_tensor = self .get_tensor (value_node , node )
89
+ # Need to reconstruct the index tensor.
90
+ # E.g., based on ScatterND Op Def in QNN Docs.
91
+ # Given that
92
+ # shape of input: [1, 12, 1024, 64]
93
+ # indicies_node: [None, None, aten__to_copy_default_1]
94
+ # shape of aten__to_copy_default_1: [1]
95
+ # The shape of index tensor should be [1, 12, 1, 3]
96
+ # The index tensor is treated as 4-dimensional tensor of 3-tuples,
97
+ # where each 3-tuple is a partial-index into input
98
+ # Reference code for QNN ScatterNd:
99
+ # output = np.copy(input)
100
+ # update_indices = indices.shape[:-1]
101
+ # for idx in np.ndindex(update_indices):
102
+ # output[indices[idx]] = updates[idx]
103
+
104
+ # Append one dimension to specify x-tuple
105
+ index_shape = target_index + [1 ]
106
+ # Reshape the index_node for tile op
107
+ reshape_shape = [
108
+ shape if id == index_node_dim else 1 for id , shape in enumerate (index_shape )
109
+ ]
110
+ reshape_output_tensor = indice_tensor .reshape (reshape_shape )
111
+ reshape_output_tensor_wrapper = self .define_custom_tensor_wrapper (
112
+ node_name = node .name + "_reshape" ,
113
+ tensor_type = PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
114
+ dtype = QNN_TENSOR_TYPE_MAP [reshape_output_tensor .dtype ],
115
+ quant_encoding = PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_UNDEFINED ,
116
+ quant_configs = {},
117
+ dims = reshape_output_tensor .size (),
118
+ tensor = reshape_output_tensor ,
119
+ is_fake_tensor = True ,
120
+ nodes_to_wrappers = nodes_to_wrappers ,
121
+ )
122
+ reshape_op = PyQnnWrapper .PyQnnOpWrapper (
123
+ node .name ,
124
+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
125
+ OpReshape .op_name ,
126
+ )
127
+ reshape_op .AddInputTensors ([indices_tensor_wrapper ])
128
+ reshape_op .AddOutputTensors ([reshape_output_tensor_wrapper ])
129
+ op_wrapper_list .append (reshape_op )
130
+ index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper
131
+
132
+ # Tile the index_node and concat the target index
133
+ if None in indicies_node :
134
+ tile_output_tensor = reshape_output_tensor .expand (index_shape )
135
+ # Tile the index_node to align with the shape of target_index
136
+ # Only need to tile the dim of None axis
137
+ # E.g., indicies_node: [None, None, aten__to_copy_default_1]
138
+ # Should tile the first two dimension.
139
+ multiples = [
140
+ shape if id != index_node_dim else 1
141
+ for id , shape in enumerate (index_shape )
142
+ ]
143
+ multiples_shape = [len (index_shape )]
144
+ tile_output_tensor_wrapper = self .define_custom_tensor_wrapper (
145
+ node_name = node .name + "_tile" ,
146
+ tensor_type = PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
147
+ dtype = QNN_TENSOR_TYPE_MAP [tile_output_tensor .dtype ],
148
+ quant_encoding = PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_UNDEFINED ,
149
+ quant_configs = {},
150
+ dims = tile_output_tensor .size (),
151
+ tensor = tile_output_tensor ,
152
+ is_fake_tensor = True ,
153
+ nodes_to_wrappers = nodes_to_wrappers ,
154
+ )
155
+ tile_op = PyQnnWrapper .PyQnnOpWrapper (
156
+ node .name ,
157
+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
158
+ OpTile .op_name ,
159
+ )
160
+ tile_op .AddInputTensors ([reshape_output_tensor_wrapper ])
161
+ tile_op .AddOutputTensors ([tile_output_tensor_wrapper ])
162
+ tile_op .AddTensorParam (
163
+ OpTile .param_multiples ,
164
+ PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
165
+ len (multiples_shape ),
166
+ multiples_shape ,
167
+ np .array (multiples , dtype = np .uint32 ),
168
+ True ,
169
+ )
170
+ op_wrapper_list .append (tile_op )
171
+
172
+ # Repeat index for "None" axis in indicies_node
173
+ ranges = [
174
+ torch .arange (dim , dtype = indice_tensor .dtype )
175
+ for dim in target_index [:- 1 ]
176
+ ]
177
+ target_index_shape = target_index + [len (ranges )]
178
+ target_index_tensor = torch .cartesian_prod (* ranges )
179
+ reshape_target_index_shape = [
180
+ shape if id != index_node_dim else 1
181
+ for id , shape in enumerate (target_index_shape )
182
+ ]
183
+ target_index_tensor = target_index_tensor .reshape (
184
+ reshape_target_index_shape
185
+ )
186
+ target_index_tensor = target_index_tensor .expand (
187
+ target_index_shape
188
+ ).contiguous ()
189
+ target_index_node = torch .fx .Node (
190
+ node .graph ,
191
+ node .name + "_target_index" ,
192
+ "call_function" ,
193
+ exir_ops .edge .aten .tensor .default ,
194
+ (), # args
195
+ {}, # kwargs
196
+ )
197
+ target_index_tensor_wrapper = self .define_tensor (
198
+ target_index_node ,
199
+ node ,
200
+ target_index_tensor ,
201
+ PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_STATIC ,
202
+ nodes_to_wrappers ,
203
+ )
62
204
205
+ # Concat target_index and tile output to reconstruct index_node
206
+ # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype
207
+ concat_output_tensor = torch .concat (
208
+ (target_index_tensor , tile_output_tensor ), dim = - 1
209
+ )
210
+ concat_output_tensor_wrapper = self .define_custom_tensor_wrapper (
211
+ node_name = node .name + "_concat" ,
212
+ tensor_type = PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
213
+ dtype = QNN_TENSOR_TYPE_MAP [concat_output_tensor .dtype ],
214
+ quant_encoding = PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_UNDEFINED ,
215
+ quant_configs = {},
216
+ dims = concat_output_tensor .size (),
217
+ tensor = concat_output_tensor ,
218
+ is_fake_tensor = True ,
219
+ nodes_to_wrappers = nodes_to_wrappers ,
220
+ )
221
+ concat_op = PyQnnWrapper .PyQnnOpWrapper (
222
+ node .name ,
223
+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
224
+ OpConcat .op_name ,
225
+ )
226
+ concat_op .AddInputTensors (
227
+ [target_index_tensor_wrapper , tile_output_tensor_wrapper ]
228
+ )
229
+ concat_op .AddOutputTensors ([concat_output_tensor_wrapper ])
230
+ concat_op .AddScalarParam (
231
+ OpConcat .param_axis ,
232
+ PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
233
+ {QCOM_DATA : np .uint32 (concat_output_tensor .dim () - 1 )},
234
+ )
235
+ op_wrapper_list .append (concat_op )
236
+ index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper
237
+
238
+ value_node = self .get_node (node .args [2 ])
239
+ value_tensor = self .get_tensor (value_node , node )
63
240
value_tensor_wrapper = self .define_tensor (
64
241
value_node ,
65
242
node ,
66
243
value_tensor ,
67
244
PyQnnWrapper .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
68
245
nodes_to_wrappers ,
69
246
)
247
+
70
248
output_tensor = self .get_tensor (node , node )
71
249
output_tensor_wrapper = self .define_tensor (
72
250
node ,
@@ -82,8 +260,12 @@ def define_node(
82
260
OpScatterNd .op_name ,
83
261
)
84
262
index_put_op .AddInputTensors (
85
- [input_tensor_wrapper , indices_tensor_wrapper , value_tensor_wrapper ]
263
+ [
264
+ input_tensor_wrapper ,
265
+ index_put_index_input_tensor_wrapper ,
266
+ value_tensor_wrapper ,
267
+ ]
86
268
)
87
269
index_put_op .AddOutputTensors ([output_tensor_wrapper ])
88
-
89
- return index_put_op
270
+ op_wrapper_list . append ( index_put_op )
271
+ return op_wrapper_list
0 commit comments