Skip to content

Commit e4e2790

Browse files
committed
Qualcomm AI Engine Direct - Fixed uint16 tensor
Summary: - Fixed the uint16 data type handling for tensors - Fixed the "argument of type 'NoneType' is not iterable" bug in the linear op
1 parent bae0387 commit e4e2790

File tree

3 files changed

+10
-17
lines changed

3 files changed

+10
-17
lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

17-
from .qnn_constants import QNN_uint16
18-
1917
from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
2018

2119

@@ -26,7 +24,7 @@
2624
# Note that there is no int64 tensor data type in Qnn.
2725
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
2826
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
29-
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
27+
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
3028
}
3129
QNN_TENSOR_TYPE_MAP = {
3230
torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
@@ -35,7 +33,7 @@
3533
torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
3634
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
3735
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
38-
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
36+
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
3937
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
4038
}
4139

@@ -169,7 +167,7 @@ def get_quant_encoding_conf(
169167
return self.make_qnn_per_tensor_config(quant_attrs)
170168

171169
def get_quant_tensor_value(
172-
self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
170+
self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
173171
) -> torch.Tensor:
174172
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
175173
scale = quant_attrs["scale"]
@@ -178,16 +176,11 @@ def get_quant_tensor_value(
178176
scale = quant_attrs["scales"]
179177
zero_point = quant_attrs["zero_points"]
180178

181-
# To bypass torch.uint16 quantization is not supported
182-
dtype = (
183-
torch.int32
184-
if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
185-
else quant_attrs["dtype"]
186-
)
179+
dtype = quant_configs["dtype"]
187180

188181
tensor = tensor.div(scale).add(zero_point).round().to(dtype)
189182
# Make the backends access data correctly
190-
if bitwidth == 4:
183+
if quant_configs.get("bitwidth") == 4:
191184
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
192185
tensor = torch.bitwise_and(mask, tensor)
193186
return tensor
@@ -236,7 +229,7 @@ def get_data_type(
236229
<= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
237230
):
238231
if unsigned:
239-
quant_config["dtype"] = QNN_uint16
232+
quant_config["dtype"] = torch.uint16
240233
else:
241234
quant_config["dtype"] = torch.int16
242235
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
@@ -327,8 +320,7 @@ def define_tensor(
327320
tensor = self.get_quant_tensor_value(
328321
tensor,
329322
node.meta["quant_attrs"],
330-
dtype,
331-
quant_configs.get("bitwidth"),
323+
quant_configs,
332324
)
333325
tensor_wrapper = PyQnnWrapper.TensorWrapper(
334326
tensor_name,

backends/qualcomm/builders/op_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def define_node(
6262
bias_node = node.args[2]
6363

6464
# TODO remove this when qnn sdk support
65-
if "scales" in bias_node.meta.get("quant_attrs"):
65+
if (
66+
quant_attrs := bias_node.meta.get("quant_attrs")
67+
) and "scales" in quant_attrs:
6668
print(
6769
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
6870
)

backends/qualcomm/builders/qnn_constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from enum import IntEnum, unique
99

1010
QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw"
11-
QNN_uint16 = "uint16"
1211

1312
# Below constants should be same as those in QNN headers.
1413
# Maybe someday we should expose these constants by pybind

0 commit comments

Comments (0)