Skip to content

Commit f2bb89b

Browse files
committed
Qualcomm AI Engine Direct - Fixed uint16 tensor data type
Summary: - Fixed the uint16 data type of tensors by using torch.uint16 directly instead of the QNN_uint16 string constant - Fixed the bug "argument of type 'NoneType' is not iterable" in the linear op
1 parent 73599f4 commit f2bb89b

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

17-
from .qnn_constants import QNN_uint16
18-
1917
from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
2018

2119

@@ -26,7 +24,7 @@
2624
# Note that there is no int64 tensor data type in Qnn.
2725
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
2826
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
29-
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
27+
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
3028
}
3129
QNN_TENSOR_TYPE_MAP = {
3230
torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
@@ -36,7 +34,7 @@
3634
torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
3735
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
3836
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
39-
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
37+
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
4038
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
4139
}
4240

@@ -170,7 +168,7 @@ def get_quant_encoding_conf(
170168
return self.make_qnn_per_tensor_config(quant_attrs)
171169

172170
def get_quant_tensor_value(
173-
self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
171+
self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
174172
) -> torch.Tensor:
175173
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
176174
scale = quant_attrs["scale"]
@@ -179,16 +177,11 @@ def get_quant_tensor_value(
179177
scale = quant_attrs["scales"]
180178
zero_point = quant_attrs["zero_points"]
181179

182-
# To bypass torch.uint16 quantization is not supported
183-
dtype = (
184-
torch.int32
185-
if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
186-
else quant_attrs["dtype"]
187-
)
180+
dtype = quant_configs["dtype"]
188181

189182
tensor = tensor.div(scale).add(zero_point).round().to(dtype)
190183
# Make the backends access data correctly
191-
if bitwidth == 4:
184+
if quant_configs.get("bitwidth") == 4:
192185
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
193186
tensor = torch.bitwise_and(mask, tensor)
194187
return tensor
@@ -237,7 +230,7 @@ def get_data_type(
237230
<= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
238231
):
239232
if unsigned:
240-
quant_config["dtype"] = QNN_uint16
233+
quant_config["dtype"] = torch.uint16
241234
else:
242235
quant_config["dtype"] = torch.int16
243236
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
@@ -328,8 +321,7 @@ def define_tensor(
328321
tensor = self.get_quant_tensor_value(
329322
tensor,
330323
node.meta["quant_attrs"],
331-
dtype,
332-
quant_configs.get("bitwidth"),
324+
quant_configs,
333325
)
334326
tensor_wrapper = PyQnnWrapper.TensorWrapper(
335327
tensor_name,

backends/qualcomm/builders/qnn_constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from enum import IntEnum, unique
99

1010
QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw"
11-
QNN_uint16 = "uint16"
1211

1312
# Below constants should be same as those in QNN headers.
1413
# Maybe someday we should expose these constants by pybind

0 commit comments

Comments
 (0)