 from executorch.exir.dialects._ops import ops as exir_ops

-from .qnn_constants import QNN_uint16
-
 from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
     # Note that there is no int64 tensor data type in Qnn.
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }
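Note on the two tables above: with torch.uint16 available as a native dtype in recent PyTorch releases, the custom QNN_uint16 placeholder from qnn_constants is no longer needed and both lookup maps can be keyed on the torch dtype directly. A minimal sketch of how such a lookup might be used follows; pick_qnn_dtype is a hypothetical helper name, and the two maps above (plus PyQnnWrapper) are assumed to be in scope.

import torch

def pick_qnn_dtype(dtype: torch.dtype, quantized: bool):
    # Quantized tensors resolve through the fixed-point map; plain tensors use the tensor map.
    table = QNN_QUANT_TYPE_MAP if quantized else QNN_TENSOR_TYPE_MAP
    if dtype not in table:
        raise KeyError(f"Unsupported dtype for QNN lowering: {dtype}")
    return table[dtype]

# Example: a quantized uint16 activation resolves to QNN_DATATYPE_UFIXED_POINT_16.
# pick_qnn_dtype(torch.uint16, quantized=True)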
@@ -169,7 +167,7 @@ def get_quant_encoding_conf(
         return self.make_qnn_per_tensor_config(quant_attrs)

     def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
             scale = quant_attrs["scale"]
@@ -178,16 +176,11 @@ def get_quant_tensor_value(
             scale = quant_attrs["scales"]
             zero_point = quant_attrs["zero_points"]

-        # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
+        dtype = quant_configs["dtype"]

         tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
             tensor = torch.bitwise_and(mask, tensor)
         return tensor
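The refactor threads the whole quant_configs dict into get_quant_tensor_value, so the storage dtype (which can now be torch.uint16 directly) and the bitwidth both come from the quantization config instead of being passed separately and translated back from a QNN enum. The core operation is plain affine quantization; below is a standalone sketch assuming only PyTorch, where quantize_per_tensor_sketch is a made-up name and not the PR's API.

from typing import Optional

import torch

def quantize_per_tensor_sketch(
    x: torch.Tensor,
    scale: float,
    zero_point: int,
    dtype: torch.dtype,
    bitwidth: Optional[int] = None,
) -> torch.Tensor:
    # q = round(x / scale + zero_point), cast to the storage dtype from the quant config.
    q = x.div(scale).add(zero_point).round().to(dtype)
    if bitwidth == 4:
        # Keep only the low nibble so the backend reads 4-bit values correctly.
        mask = torch.full(q.size(), 0x0F, dtype=torch.int8)
        q = torch.bitwise_and(mask, q)
    return q

# Example: 4-bit weights stored in int8 containers.
# quantize_per_tensor_sketch(torch.randn(8), scale=0.1, zero_point=8, dtype=torch.int8, bitwidth=4)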
@@ -236,7 +229,7 @@ def get_data_type(
                 <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
             ):
                 if unsigned:
-                    quant_config["dtype"] = QNN_uint16
+                    quant_config["dtype"] = torch.uint16
                 else:
                     quant_config["dtype"] = torch.int16
         return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
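The dtype selection above encodes a simple rule: when the quantization range fits in 16 bits, pick torch.uint16 for an unsigned range and torch.int16 otherwise, with torch.uint16 replacing the old QNN_uint16 sentinel. A minimal sketch of that decision, assuming a PyTorch version that provides torch.uint16; pick_16bit_dtype is a hypothetical name.

import torch

def pick_16bit_dtype(quant_min: int, quant_max: int) -> torch.dtype:
    # A range no wider than 65535 fits in a 16-bit container.
    assert quant_max - quant_min <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
    # A non-negative quant_min means the quantized values are unsigned.
    return torch.uint16 if quant_min >= 0 else torch.int16

# Example: pick_16bit_dtype(0, 65535) returns torch.uint16.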
@@ -327,8 +320,7 @@ def define_tensor(
             tensor = self.get_quant_tensor_value(
                 tensor,
                 node.meta["quant_attrs"],
-                dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
             )
         tensor_wrapper = PyQnnWrapper.TensorWrapper(
             tensor_name,