Commit 468b5f8 (parent: ab323a5)

Qualcomm AI Engine Direct - Enable SSD300_VGG16

File tree: 11 files changed (+586, -53 lines)

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -41,8 +41,10 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_sqrt,
     op_squeeze,
     op_sub,
+    op_sum_int_list,
     op_tanh,
     op_transpose,
     op_unsqueeze,
@@ -86,7 +88,9 @@
     op_slice_copy,
     op_softmax,
     op_squeeze,
+    op_sqrt,
     op_sub,
+    op_sum_int_list,
     op_tanh,
     op_transpose,
     op_unsqueeze,

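These additions only take effect because the builder package imports the new modules: importing op_sqrt and op_sum_int_list runs their @register_node_visitor decorators. A minimal sketch of that decorator-registry pattern; REGISTRY and MyOp are illustrative names, not the actual ExecuTorch internals.

# Illustrative sketch of a decorator-based registry; REGISTRY and MyOp are
# made-up names, not the real ExecuTorch implementation.
from typing import Dict, List, Type

REGISTRY: Dict[str, Type] = {}

def register_node_visitor(cls: Type) -> Type:
    # Decorators run at import time, which is why the builder package's
    # __init__.py must import op_sqrt / op_sum_int_list for them to be found.
    for target in cls.target:
        REGISTRY[target] = cls
    return cls

@register_node_visitor
class MyOp:
    target: List[str] = ["aten.my_op.default"]

assert "aten.my_op.default" in REGISTRY
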
backends/qualcomm/builders/op_linear.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def define_node(
             bias_node = node.args[2]
 
             # TODO remove this when qnn sdk support
-            if "scales" in bias_node.meta.get("quant_attrs"):
+            if "scales" in bias_node.meta.get("quant_attrs", {}):
                 print(
                     f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
                 )

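The one-line change above guards against bias nodes whose meta carries no quant_attrs entry: without the default, dict.get returns None and the membership test raises TypeError. A standalone sketch of the failure mode, where meta is a plain dict standing in for node.meta:

# meta stands in for a bias node's .meta with no quantization attributes.
meta = {}

try:
    "scales" in meta.get("quant_attrs")        # old behaviour
except TypeError as err:
    print(f"old: {err}")                       # argument of type 'NoneType' is not iterable

print("new:", "scales" in meta.get("quant_attrs", {}))  # new behaviour: simply False
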
backends/qualcomm/builders/op_log_softmax.py

Lines changed: 0 additions & 1 deletion
@@ -72,5 +72,4 @@ def define_node(
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
             {"data": np.uint32(dim)},
         )
-        # pdb.set_trace()
         return log_softmax_op

backends/qualcomm/builders/op_sqrt.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class SQRT(NodeVisitor):
+    target = ["aten.sqrt.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        # tensor input
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        sqrt_input_tensors = [input_tensor_wrapper]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+        sqrt_output_tensors = [output_tensor_wrapper]
+
+        sqrt_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpSqrt.op_name,
+        )
+        sqrt_op.AddInputTensors(sqrt_input_tensors)
+        sqrt_op.AddOutputTensors(sqrt_output_tensors)
+
+        return sqrt_op
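
With this visitor registered, any graph node whose target is aten.sqrt.default can be mapped to QNN's ElementWiseSquareRoot. A quick way to confirm that torch.sqrt exports to that target, using only standard torch.export; the QNN delegation step is omitted and the exact node list may vary with export/decomposition settings.

import torch

class Sqrt(torch.nn.Module):
    def forward(self, x):
        return torch.sqrt(x)

ep = torch.export.export(Sqrt(), (torch.rand(4, 4),))
targets = [str(n.target) for n in ep.graph_module.graph.nodes if n.op == "call_function"]
print(targets)  # expected to contain 'aten.sqrt.default'
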
backends/qualcomm/builders/op_sum_int_list.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpReduceSum, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Sum(NodeVisitor):
+    target = ["aten.sum.dim_IntList"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        sum_input_tensors = [input_tensor_wrapper]
+
+        # sum dims
+        sum_dims = cast(List[int], node.args[1])
+        sum_dims = [sum_dim % len(input_node.meta["val"].shape) for sum_dim in sum_dims]
+        if "axis_order" in node.meta:
+            sum_dims = [node.meta["axis_order"].index(sum_dim) for sum_dim in sum_dims]
+        sum_dims_shape = [len(sum_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+        sum_output_tensors = [output_tensor_wrapper]
+        sum_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReduceSum.op_name,
+        )
+        sum_op.AddInputTensors(sum_input_tensors)
+        sum_op.AddOutputTensors(sum_output_tensors)
+        sum_op.AddTensorParam(
+            OpReduceSum.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(sum_dims_shape),
+            sum_dims_shape,
+            np.array(sum_dims, dtype=np.uint32),
+            True,
+        )
+
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            sum_op.AddScalarParam(
+                OpReduceSum.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {"data": keep_dims},
+            )
+        return sum_op

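The axis handling in define_node is the subtle part: negative dims are normalized against the input rank, and when the layout pass has recorded an axis_order permutation, the reduce axes are remapped into that order. A standalone sketch of the same arithmetic, assuming an NCHW-to-NHWC permutation recorded as axis_order = (0, 2, 3, 1):

# Reproduces the two list comprehensions from define_node on plain Python data.
rank = 4                      # rank of the input tensor
axis_order = (0, 2, 3, 1)     # assumed NCHW -> NHWC permutation from the layout pass

sum_dims = [-3]                                      # channel dim of an NCHW tensor
sum_dims = [d % rank for d in sum_dims]              # normalize negatives -> [1]
sum_dims = [axis_order.index(d) for d in sum_dims]   # channels live at axis 3 in NHWC
print(sum_dims)  # [3]
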
backends/qualcomm/builders/qnn_constants.py

Lines changed: 12 additions & 0 deletions
@@ -106,6 +106,13 @@ class OpExpandDims:
     param_axis: str = "axis"
 
 
+@dataclass(init=False, frozen=True)
+class OpReduceSum:
+    op_name: str = "ReduceSum"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpFullyConnected:
     op_name: str = "FullyConnected"
@@ -123,6 +130,11 @@ class OpGelu:
     op_name: str = "Gelu"
 
 
+@dataclass(init=False, frozen=True)
+class OpSqrt:
+    op_name: str = "ElementWiseSquareRoot"
+
+
 @dataclass(init=False, frozen=True)
 class OpHardSwish:
     op_name: str = "HardSwish"

backends/qualcomm/passes/layout_transform.py

Lines changed: 7 additions & 1 deletion
@@ -52,6 +52,9 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.sqrt.default,
+        exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.pow.Tensor_Scalar,
         *q_ops,
         *dq_ops,
         _operator.getitem,
@@ -109,7 +112,10 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
         return node.target in self.layout_sensitive_ops
 
     def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
-        if node.target == exir_ops.edge.aten.mean.dim:
+        if node.target in [
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.sum.dim_IntList,
+        ]:
             # if dimemsion is not kept, we'll have no clue how to do layout transform
             if len(node.args) < 3 or not node.args[2]:
                 return False

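sum.dim_IntList joins mean.dim in the keep-dims check above: the pass only treats the reduction as layout-agnostic when node.args[2] (keepdim) is truthy, because a dropped dimension leaves no way to reapply the recorded axis permutation. A small shape check illustrating why:

import torch

x = torch.rand(1, 512, 38, 38)              # NCHW input
print(x.sum(dim=1, keepdim=True).shape)     # torch.Size([1, 1, 38, 38]) - rank kept, the NHWC permutation still applies
print(x.sum(dim=1, keepdim=False).shape)    # torch.Size([1, 38, 38])    - rank dropped, the permutation no longer describes the output
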
backends/qualcomm/quantizer/utils.py

Lines changed: 11 additions & 0 deletions
@@ -42,6 +42,7 @@ def decorator(annotator: Callable):
 
     return decorator
 
+
 def _is_input_float_tensor(node: Node):
     """Check if the input is not a float tensor, so that we can skip quantization for the node
     since observers only works with float Tensors
@@ -175,6 +176,11 @@ def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.sum.dim_IntList])
+def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.ceil.default])
 def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
@@ -302,6 +308,11 @@ def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.sqrt.default])
+def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.gelu.default])
 def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)

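The new annotators reuse existing helpers: sqrt is treated as single-in/single-out while sum is routed through annotate_binary. Taken together with the new builders, they cover the sqrt / sum(dim, keepdim=True) / pow pattern of an SSD-style L2Norm layer, which is presumably what this commit needs for SSD300_VGG16. A hedged end-to-end sketch using only standard torch.export; QNN quantization and partitioning are omitted, and the exact op set may vary with export settings.

import torch

class L2Norm(torch.nn.Module):
    # Channel-wise L2 normalization in the style of SSD300_VGG16's conv4_3 branch.
    def forward(self, x):
        norm = torch.sqrt(x.pow(2).sum(dim=1, keepdim=True) + 1e-10)
        return x / norm

ep = torch.export.export(L2Norm(), (torch.rand(1, 512, 38, 38),))
ops = sorted({str(n.target) for n in ep.graph_module.graph.nodes if n.op == "call_function"})
print(ops)  # expected to include aten.pow.Tensor_Scalar, aten.sum.dim_IntList, aten.sqrt.default
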