 from executorch.exir.dialects._ops import ops as exir_ops
 
-from .qnn_constants import QNN_uint16
-
 from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
 
@@ -26,7 +24,7 @@
     # Note that there is no int64 tensor data type in Qnn.
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
     torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
@@ -36,7 +34,7 @@
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }
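
A quick sketch of what the change above buys at lookup time: with PyTorch's native torch.uint16 used as the key, the project-specific QNN_uint16 alias is no longer needed. Assuming the snippet runs inside the same module that defines the two maps:

import torch

# Direct lookups with the native dtype; values follow the maps above.
quantized_qnn_type = QNN_QUANT_TYPE_MAP[torch.uint16]   # QNN_DATATYPE_UFIXED_POINT_16
plain_qnn_type = QNN_TENSOR_TYPE_MAP[torch.uint16]      # QNN_DATATYPE_UINT_16
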
@@ -170,7 +168,7 @@ def get_quant_encoding_conf(
         return self.make_qnn_per_tensor_config(quant_attrs)
 
     def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
             scale = quant_attrs["scale"]
@@ -179,16 +177,11 @@ def get_quant_tensor_value(
             scale = quant_attrs["scales"]
             zero_point = quant_attrs["zero_points"]
 
-        # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
+        dtype = quant_configs["dtype"]
 
         tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
             tensor = torch.bitwise_and(mask, tensor)
         return tensor
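
For readers skimming the hunk above, here is a self-contained sketch of the per-tensor path of the rewritten method, with dtype and bitwidth now read from quant_configs. The helper name and the "zero_point" key are assumptions for illustration, not the file's API:

import torch

def quantize_per_tensor_sketch(tensor, quant_attrs, quant_configs):
    # Affine quantization: q = round(x / scale) + zero_point, cast to the target dtype.
    scale = quant_attrs["scale"]            # assumed per-tensor key, mirroring the diff
    zero_point = quant_attrs["zero_point"]  # assumed per-tensor key
    q = tensor.div(scale).add(zero_point).round().to(quant_configs["dtype"])
    # 4-bit values are confined to the low nibble so the backend reads them correctly.
    if quant_configs.get("bitwidth") == 4:
        mask = torch.full(q.size(), 0x0F, dtype=torch.int8)
        q = torch.bitwise_and(mask, q)
    return q

# Example: quantize to uint8 with scale 0.1 and zero point 128.
q = quantize_per_tensor_sketch(
    torch.tensor([0.0, 1.0, 2.5]),
    {"scale": 0.1, "zero_point": 128},
    {"dtype": torch.uint8},
)
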
@@ -237,7 +230,7 @@ def get_data_type(
             <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
         ):
             if unsigned:
-                quant_config["dtype"] = QNN_uint16
+                quant_config["dtype"] = torch.uint16
             else:
                 quant_config["dtype"] = torch.int16
         return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
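
The hunk above only swaps the alias for torch.uint16; the rule it sits in, roughly, is that a quantization range no wider than 16 bits maps to uint16 when the range is unsigned and to int16 otherwise. A rough standalone illustration follows; how `unsigned` is actually derived is not shown in this diff, so the quant_min >= 0 test below is an assumption:

import torch

def pick_16bit_dtype(quant_min: int, quant_max: int) -> torch.dtype:
    unsigned = quant_min >= 0  # assumption: an all-non-negative range is treated as unsigned
    if quant_max - quant_min <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
        return torch.uint16 if unsigned else torch.int16
    raise ValueError("quantization range does not fit in 16 bits")

# pick_16bit_dtype(0, 65535)      -> torch.uint16
# pick_16bit_dtype(-32768, 32767) -> torch.int16
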
@@ -328,8 +321,7 @@ def define_tensor(
             tensor = self.get_quant_tensor_value(
                 tensor,
                 node.meta["quant_attrs"],
-                dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
             )
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,