
Commit 5b2129e

Enable sin, cos op for QNN HTP backend (#6591)
- Add sin / cos to the op builders
- Update the quantizer to adopt the new operators
- Add test cases for both the fp16 and quantized versions
- Add a docstring for the build_executorch_binary function
1 parent 8861b9a commit 5b2129e
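
As context for what this enables (an illustrative sketch, not part of the commit): once the builders below are registered, aten.sin.default and aten.cos.default nodes in an exported graph have matching QNN node visitors. A minimal module exercising both ops can be captured with torch.export to inspect those nodes; the module name and input shape here are made up for illustration.

import torch


class SinCos(torch.nn.Module):
    # Illustrative module; the commit's own test models define Sin and Cos separately.
    def forward(self, x):
        return torch.sin(x) + torch.cos(x)


# Capture the aten-level graph. The resulting aten.sin.default / aten.cos.default
# call_function nodes are what op_sin.py / op_cos.py lower to QNN ElementWiseSin /
# ElementWiseCos ops.
example_inputs = (torch.randn(2, 5, 1, 3),)
exported_program = torch.export.export(SinCos(), example_inputs)
print(exported_program.graph_module.graph)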

File tree

backends/qualcomm/builders/__init__.py
backends/qualcomm/builders/op_cos.py
backends/qualcomm/builders/op_sin.py
backends/qualcomm/builders/qnn_constants.py
backends/qualcomm/quantizer/annotators.py
backends/qualcomm/tests/models.py
backends/qualcomm/tests/test_qnn_delegate.py
examples/qualcomm/utils.py

8 files changed: +195 −0 lines changed

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -43,6 +44,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
@@ -71,6 +73,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -100,6 +103,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,

backends/qualcomm/builders/op_cos.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseCos, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Cos(NodeVisitor):
+    target = ["aten.cos.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
+        cos_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseCos.op_name,
+        )
+        cos_op.AddInputTensors([input_tensor_wrapper])
+        cos_op.AddOutputTensors([output_tensor_wrapper])
+
+        return cos_op

backends/qualcomm/builders/op_sin.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseSin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Sin(NodeVisitor):
+    target = ["aten.sin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
+        sin_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseSin.op_name,
+        )
+        sin_op.AddInputTensors([input_tensor_wrapper])
+        sin_op.AddOutputTensors([output_tensor_wrapper])
+
+        return sin_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 10 additions & 0 deletions
@@ -85,6 +85,11 @@ class OpElementWiseCeil:
     op_name = "ElementWiseCeil"


+@dataclass(init=False, frozen=True)
+class OpElementWiseCos:
+    op_name: str = "ElementWiseCos"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseDivide:
     op_name: str = "ElementWiseDivide"
@@ -113,6 +118,11 @@ class OpElementWiseRsqrt:
     op_name: str = "ElementWiseRsqrt"


+@dataclass(init=False, frozen=True)
+class OpElementWiseSin:
+    op_name: str = "ElementWiseSin"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"

backends/qualcomm/quantizer/annotators.py

Lines changed: 10 additions & 0 deletions
@@ -271,6 +271,16 @@ def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.cos.default])
+def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
+@register_annotator([torch.ops.aten.sin.default])
+def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.tanh.default])
 def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
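
Both ops are single-input, single-output, so they reuse the existing annotate_single_in_single_out helper. A hypothetical sketch (not part of this commit) of how another unary aten op could be wired into the quantizer the same way, assuming the same imports and helpers as annotators.py:

@register_annotator([torch.ops.aten.tan.default])
def annotate_tan(node: Node, quantization_config: QuantizationConfig) -> None:
    # Same pattern as annotate_sin / annotate_cos above: one activation in, one out.
    annotate_single_in_single_out(node, quantization_config)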

backends/qualcomm/tests/models.py

Lines changed: 16 additions & 0 deletions
@@ -427,6 +427,14 @@ def forward(self, x):
         return topk_values


+class Cos(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.cos(x)
+
+
 class Div(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -889,6 +897,14 @@ def forward(self, x):
         return torch.sigmoid(x)


+class Sin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.sin(x)
+
+
 class SimpleModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
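
A quick eager-mode sanity check (illustrative, not part of the commit) of the new test modules; they are parameter-free wrappers around the corresponding torch functions, which is what the fp16 and quantized delegate tests below compare against:

import torch

# Assumes the Cos / Sin modules defined above in backends/qualcomm/tests/models.py.
x = torch.randn(2, 5, 1, 3)
assert torch.equal(Cos()(x), torch.cos(x))
assert torch.equal(Sin()(x), torch.sin(x))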

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 22 additions & 0 deletions
@@ -143,6 +143,11 @@ def test_qnn_backend_conv_transpose2d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_cos(self):
+        module = Cos()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
         x = torch.randn(5)
@@ -465,6 +470,11 @@ def test_qnn_backend_sigmoid(self):
         sample_input = (torch.randn([1, 3, 3, 3]),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_sin(self):
+        module = Sin()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_select_copy(self):
         module = SelectCopy()  # noqa: F405
         sample_input = (torch.randn([1, 3, 3, 3]),)
@@ -825,6 +835,12 @@ def test_qnn_backend_conv_transpose2d(self):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_cos(self):
+        module = Cos()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
         x = torch.randn(5)
@@ -1201,6 +1217,12 @@ def test_qnn_backend_sigmoid(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_sin(self):
+        module = Sin()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_slice_copy(self):
         modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
         sample_input = (

examples/qualcomm/utils.py

Lines changed: 21 additions & 0 deletions
@@ -256,6 +256,27 @@ def build_executorch_binary(
     custom_pass_config=frozenset(),
     qat_training_data=None,
 ):
+    """
+    Generate an ExecuTorch binary for Qualcomm platforms.
+
+    Args:
+        model (torch.nn.Module): The model to be converted into an ExecuTorch binary.
+        inputs (torch.Tensor): Sample input tensors required for model export.
+        soc_model (QcomChipset): The target Qualcomm System on Chip (SoC) model.
+        file_name (str): Name of the output binary file (.pte).
+        dataset (List[torch.Tensor] | Callable): A dataset for quantization calibration.
+        skip_node_id_set (set, optional): Set of node IDs to be skipped during partitioning.
+        skip_node_op_set (set, optional): Set of node operations to be skipped during partitioning.
+        quant_dtype (QuantDtype, optional): Data type for quantization.
+        custom_quantizer (Callable, optional): Custom quantizer.
+        shared_buffer (bool, optional): Applies a zero-copy mechanism to optimize runtime memory allocation.
+        metadata (dict, optional): An optional dictionary that maps each method name to a constant value in eager mode.
+        dump_intermediate_outputs (bool, optional): Enables dumping the model's intermediate outputs.
+        custom_pass_config (frozenset, optional): Set of custom passes for model processing.
+
+    Returns:
+        None: The function writes the output to the specified .pte file.
+    """
     if quant_dtype is not None:
         captured_model = torch.export.export(model, inputs).module()
         if qat_training_data:
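
A hedged sketch of how the documented function might be called for one of the new ops; the import paths, SoC identifier, and argument values below are assumptions based on the docstring above and the other examples under examples/qualcomm, not something this commit adds:

import torch
from executorch.backends.qualcomm.tests.models import Sin  # assumed import path
from executorch.examples.qualcomm.utils import build_executorch_binary  # assumed import path

module = Sin()
sample_inputs = (torch.randn(2, 5, 1, 3),)

# Floating-point (fp16) lowering: no calibration dataset or quant_dtype supplied.
build_executorch_binary(
    module,
    sample_inputs,
    soc_model="SM8550",   # assumed SoC identifier
    file_name="sin_qnn",  # name of the generated .pte
    dataset=None,         # calibration data only needed for the quantized path
)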
