Skip to content

Commit a18e0c7

Browse files
committed
Qualcomm AI Engine Direct - Refactor & centralize common keywords
Summary: - Centralize the QCOM-specific keywords as named constants - Replace the hard-coded string literals throughout the Qualcomm code base with these constants
1 parent 8bdafb0 commit a18e0c7

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+359
-252
lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111

1212
import numpy as np
1313
import torch
14+
from executorch.backends.qualcomm.utils.constants import (
15+
QCOM_AXIS_ORDER,
16+
QCOM_BITWIDTH,
17+
QCOM_ENCODING,
18+
QCOM_QUANT_ATTRS,
19+
QCOM_REQUANTIZE,
20+
QCOM_SCALE_OFFSET,
21+
QCOM_SCALES,
22+
QCOM_ZERO_POINTS,
23+
)
1424

1525
from executorch.exir.dialects._ops import ops as exir_ops
1626

@@ -89,15 +99,15 @@ def _get_tensor(node, index):
8999
return node.meta["val"]
90100

91101
tensor = _get_tensor(input_node, idx)
92-
if len(tensor.shape) != 0 and "axis_order" in op_node.meta:
93-
tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous()
102+
if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
103+
tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
94104
return tensor
95105

96106
def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
97107
quant_config = copy.deepcopy(quant_attrs)
98108

99-
scales = quant_attrs["scales"]
100-
zero_points = quant_attrs["zero_points"]
109+
scales = quant_attrs[QCOM_SCALES]
110+
zero_points = quant_attrs[QCOM_ZERO_POINTS]
101111
assert len(scales) == len(
102112
zero_points
103113
), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"
@@ -120,13 +130,13 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
120130
else:
121131
quant_config["axis"] = quant_attrs["axis"]
122132

123-
quant_config["scale_offset"] = scale_offset
133+
quant_config[QCOM_SCALE_OFFSET] = scale_offset
124134
# special case for 4 bits
125135
if (
126136
quant_config["dtype"] == torch.int8
127137
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
128138
):
129-
quant_config["bitwidth"] = 4
139+
quant_config[QCOM_BITWIDTH] = 4
130140
return (
131141
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
132142
quant_config,
@@ -145,7 +155,7 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
145155
quant_config["dtype"] == torch.int8
146156
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
147157
):
148-
quant_config["bitwidth"] = 4
158+
quant_config[QCOM_BITWIDTH] = 4
149159
return (
150160
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
151161
quant_config,
@@ -158,36 +168,36 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
158168
def get_quant_encoding_conf(
159169
self, node: torch.fx.Node, is_input_tensor: bool = False
160170
) -> Tuple[Any, Dict]:
161-
if not node.meta.get("quant_attrs", None):
171+
if not node.meta.get(QCOM_QUANT_ATTRS, None):
162172
return (
163173
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
164174
{},
165175
)
166176
quant_attrs = (
167-
node.meta["requantize"]
168-
if "requantize" in node.meta and is_input_tensor
169-
else node.meta["quant_attrs"]
177+
node.meta[QCOM_REQUANTIZE]
178+
if QCOM_REQUANTIZE in node.meta and is_input_tensor
179+
else node.meta[QCOM_QUANT_ATTRS]
170180
)
171-
if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
181+
if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
172182
return self.make_qnn_per_channel_config(node, quant_attrs)
173183

174184
return self.make_qnn_per_tensor_config(quant_attrs)
175185

176186
def get_quant_tensor_value(
177187
self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
178188
) -> torch.Tensor:
179-
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
189+
if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
180190
scale = quant_attrs["scale"]
181191
zero_point = quant_attrs["zero_point"]
182192
else: # per channel case
183-
scale = quant_attrs["scales"]
184-
zero_point = quant_attrs["zero_points"]
193+
scale = quant_attrs[QCOM_SCALES]
194+
zero_point = quant_attrs[QCOM_ZERO_POINTS]
185195

186196
dtype = quant_configs["dtype"]
187197

188198
tensor = tensor.div(scale).add(zero_point).round().to(dtype)
189199
# Make the backends access data correctly
190-
if quant_configs.get("bitwidth") == 4:
200+
if quant_configs.get(QCOM_BITWIDTH) == 4:
191201
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
192202
tensor = torch.bitwise_and(mask, tensor)
193203
return tensor
@@ -315,7 +325,7 @@ def define_tensor(
315325
if quant_configs:
316326
tensor = self.get_quant_tensor_value(
317327
tensor,
318-
node.meta["quant_attrs"],
328+
node.meta[QCOM_QUANT_ATTRS],
319329
quant_configs,
320330
)
321331
tensor_wrapper = PyQnnWrapper.TensorWrapper(

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -132,12 +133,12 @@ def define_node(
132133
avg_pool2d_op.AddScalarParam(
133134
OpPoolAvg2d.param_rounding_mode,
134135
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
135-
{"data": np.uint32(mode)},
136+
{QCOM_DATA: np.uint32(mode)},
136137
)
137138
avg_pool2d_op.AddScalarParam(
138139
OpPoolAvg2d.param_count_pad_for_edges,
139140
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
140-
{"data": count_include_pad},
141+
{QCOM_DATA: count_include_pad},
141142
)
142143

143144
return avg_pool2d_op

backends/qualcomm/builders/op_cat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpConcat, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -64,8 +65,8 @@ def define_node(
6465
if axis < 0:
6566
axis += node.meta["val"].dim()
6667

67-
if "axis_order" in node.meta:
68-
axis = node.meta["axis_order"].index(axis)
68+
if QCOM_AXIS_ORDER in node.meta:
69+
axis = node.meta[QCOM_AXIS_ORDER].index(axis)
6970

7071
concat_op = PyQnnWrapper.PyQnnOpWrapper(
7172
node.name,
@@ -78,7 +79,7 @@ def define_node(
7879
concat_op.AddScalarParam(
7980
OpConcat.param_axis,
8081
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
81-
{"data": np.uint32(axis)},
82+
{QCOM_DATA: np.uint32(axis)},
8283
)
8384

8485
return concat_op

backends/qualcomm/builders/op_clamp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -67,12 +68,12 @@ def define_node(
6768
clamp_op.AddScalarParam(
6869
OpReluMinMax.param_max_value,
6970
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
70-
{"data": np.float32(output_max)},
71+
{QCOM_DATA: np.float32(output_max)},
7172
)
7273
clamp_op.AddScalarParam(
7374
OpReluMinMax.param_min_value,
7475
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
75-
{"data": np.float32(output_min)},
76+
{QCOM_DATA: np.float32(output_min)},
7677
)
7778

7879
return clamp_op

backends/qualcomm/builders/op_conv2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1314

1415
from .node_visitor import NodeVisitor, register_node_visitor
1516
from .qnn_constants import (
@@ -79,7 +80,7 @@ def _add_conv_op_parameter(
7980
conv_op.AddScalarParam(
8081
OP.param_group,
8182
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
82-
{"data": np.uint32(groups)},
83+
{QCOM_DATA: np.uint32(groups)},
8384
)
8485

8586
return conv_op
@@ -130,7 +131,7 @@ def _define_conv1d(
130131
unsqueeze_op.AddScalarParam(
131132
OpExpandDims.param_axis,
132133
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
133-
{"data": np.uint32(1)},
134+
{QCOM_DATA: np.uint32(1)},
134135
)
135136
op_wrapper_list.append(unsqueeze_op)
136137

backends/qualcomm/builders/op_depth_to_space.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1314

1415
from .node_visitor import NodeVisitor, register_node_visitor
1516
from .qnn_constants import OpDepthToSpace, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -70,7 +71,7 @@ def define_node(
7071
depth_to_space_op.AddScalarParam(
7172
OpDepthToSpace.param_mode,
7273
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
73-
{"data": np.uint32(OpDepthToSpace.Mode.CRD)},
74+
{QCOM_DATA: np.uint32(OpDepthToSpace.Mode.CRD)},
7475
)
7576

7677
return depth_to_space_op

backends/qualcomm/builders/op_embedding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -71,7 +72,7 @@ def define_node(
7172
gather_op.AddScalarParam(
7273
OpGather.param_axis,
7374
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
74-
{"data": np.int32(0)},
75+
{QCOM_DATA: np.int32(0)},
7576
)
7677

7778
return gather_op

backends/qualcomm/builders/op_hardsigmoid.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1314

1415
from .node_visitor import NodeVisitor, register_node_visitor
1516
from .qnn_constants import OpElementWiseNeuron, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -58,19 +59,19 @@ def define_node(
5859
hardsigmoid_op.AddScalarParam(
5960
OpElementWiseNeuron.param_operation,
6061
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
61-
{"data": np.uint32(2)},
62+
{QCOM_DATA: np.uint32(2)},
6263
)
6364

6465
# The parameter used in Pytorch definition for hardsigmoid
6566
hardsigmoid_op.AddScalarParam(
6667
OpElementWiseNeuron.param_alpha,
6768
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
68-
{"data": np.float32(1 / 6)},
69+
{QCOM_DATA: np.float32(1 / 6)},
6970
)
7071
hardsigmoid_op.AddScalarParam(
7172
OpElementWiseNeuron.param_beta,
7273
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
73-
{"data": np.float32(1 / 2)},
74+
{QCOM_DATA: np.float32(1 / 2)},
7475
)
7576

7677
return hardsigmoid_op

backends/qualcomm/builders/op_hardtanh.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1314

1415
from .node_visitor import NodeVisitor, register_node_visitor
1516
from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -66,12 +67,12 @@ def define_node(
6667
hardtanh_op.AddScalarParam(
6768
OpReluMinMax.param_max_value,
6869
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
69-
{"data": np.float32(output_max)},
70+
{QCOM_DATA: np.float32(output_max)},
7071
)
7172
hardtanh_op.AddScalarParam(
7273
OpReluMinMax.param_min_value,
7374
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
74-
{"data": np.float32(output_min)},
75+
{QCOM_DATA: np.float32(output_min)},
7576
)
7677

7778
return hardtanh_op

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1314

1415
from .node_visitor import NodeVisitor, register_node_visitor
1516
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -91,7 +92,7 @@ def define_node(
9192
layer_norm_op.AddScalarParam(
9293
OpLayerNorm.param_epsilon,
9394
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
94-
{"data": np.float32(epsilon)},
95+
{QCOM_DATA: np.float32(epsilon)},
9596
)
9697
layer_norm_op.AddTensorParam(
9798
OpLayerNorm.param_axes,

backends/qualcomm/builders/op_linear.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
1010

1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import (
13+
QCOM_QUANT_ATTRS,
14+
QCOM_SCALES,
15+
QCOM_ZERO_POINTS,
16+
)
1217

1318
from .node_visitor import NodeVisitor, register_node_visitor
1419
from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -41,12 +46,14 @@ def define_node(
4146

4247
weight_node = node.args[1]
4348
if (
44-
quant_attrs := weight_node.meta.get("quant_attrs")
45-
) and "scales" in quant_attrs:
49+
quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS)
50+
) and QCOM_SCALES in quant_attrs:
4651
# Dimension of weight is [m, n], per channel quant params is [m]
4752
# Change to [m, 1] to fit the tensor.div(s).add(z)
48-
quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
49-
quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])
53+
quant_attrs[QCOM_SCALES] = quant_attrs[QCOM_SCALES].reshape([-1, 1])
54+
quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
55+
[-1, 1]
56+
)
5057

5158
weight_tensor = get_parameter(weight_node, self.edge_program)
5259
weight_tensor_wrapper = self.define_tensor(
@@ -62,7 +69,7 @@ def define_node(
6269
bias_node = node.args[2]
6370

6471
# TODO remove this when qnn sdk support
65-
if "scales" in bias_node.meta.get("quant_attrs", {}):
72+
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
6673
print(
6774
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
6875
)

backends/qualcomm/builders/op_log_softmax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpLogSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -52,8 +53,8 @@ def define_node(
5253
if dim < 0:
5354
dim = dim % len(input_tensor.shape)
5455

55-
if "axis_order" in node.meta:
56-
dim = node.meta["axis_order"].index(dim)
56+
if QCOM_AXIS_ORDER in node.meta:
57+
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
5758

5859
# logsoftmax only supports last dimension for now, which is channel in QNN
5960
if dim != input_tensor.dim() - 1:
@@ -70,6 +71,6 @@ def define_node(
7071
log_softmax_op.AddScalarParam(
7172
OpLogSoftmax.param_axis,
7273
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
73-
{"data": np.uint32(dim)},
74+
{QCOM_DATA: np.uint32(dim)},
7475
)
7576
return log_softmax_op

0 commit comments

Comments (0)