Enable sin, cos op for QNN HTP backend #6591

Merged · 1 commit · Dec 5, 2024
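The change spans four layers: new op builders (op_cos.py, op_sin.py), QNN op-name constants, quantization annotators, and tests. For orientation, a minimal sketch (not part of this PR) of the graph the builders consume: torch.export lowers torch.sin and torch.cos to the aten.sin.default and aten.cos.default targets that the new NodeVisitors register for.

import torch

class SinCos(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + torch.cos(x)

# Exporting produces call_function nodes targeting torch.ops.aten.sin.default
# and torch.ops.aten.cos.default -- the keys the new builders register under.
ep = torch.export.export(SinCos(), (torch.randn(2, 5, 1, 3),))
print(ep.graph_module.graph)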
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -14,6 +14,7 @@
op_ceil,
op_clamp,
op_conv2d,
op_cos,
op_depth_to_space,
op_dequantize,
op_div,
@@ -43,6 +44,7 @@
op_rsqrt,
op_select_copy,
op_sigmoid,
op_sin,
op_skip_ops,
op_slice_copy,
op_softmax,
@@ -71,6 +73,7 @@
op_ceil,
op_clamp,
op_conv2d,
op_cos,
op_depth_to_space,
op_dequantize,
op_div,
@@ -100,6 +103,7 @@
op_rsqrt,
op_select_copy,
op_sigmoid,
op_sin,
op_skip_ops,
op_slice_copy,
op_softmax,
56 changes: 56 additions & 0 deletions backends/qualcomm/builders/op_cos.py
@@ -0,0 +1,56 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseCos, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Cos(NodeVisitor):
target = ["aten.cos.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

cos_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseCos.op_name,
)
cos_op.AddInputTensors([input_tensor_wrapper])
cos_op.AddOutputTensors([output_tensor_wrapper])

return cos_op
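The @register_node_visitor decorator keys the visitor by the strings in its target list, so the lowering pass can dispatch on node.target. The registry itself lives in node_visitor.py; the following is only a hedged sketch of that pattern (names here are assumptions, not the actual implementation):

# Assumed shape of the visitor registry, for illustration only.
_registry = {}

def register_node_visitor(visitor_cls):
    # Map each aten target string (e.g. "aten.cos.default") to its visitor.
    for target in visitor_cls.target:
        _registry[target] = visitor_cls
    return visitor_cls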
56 changes: 56 additions & 0 deletions backends/qualcomm/builders/op_sin.py
@@ -0,0 +1,56 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseSin, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Sin(NodeVisitor):
target = ["aten.sin.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

sin_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseSin.op_name,
)
sin_op.AddInputTensors([input_tensor_wrapper])
sin_op.AddOutputTensors([output_tensor_wrapper])

return sin_op
10 changes: 10 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -85,6 +85,11 @@ class OpElementWiseCeil:
op_name: str = "ElementWiseCeil"


@dataclass(init=False, frozen=True)
class OpElementWiseCos:
op_name: str = "ElementWiseCos"


@dataclass(init=False, frozen=True)
class OpElementWiseDivide:
op_name: str = "ElementWiseDivide"
@@ -113,6 +118,11 @@ class OpElementWiseRsqrt:
op_name: str = "ElementWiseRsqrt"


@dataclass(init=False, frozen=True)
class OpElementWiseSin:
op_name: str = "ElementWiseSin"


@dataclass(init=False, frozen=True)
class OpElementWiseSubtract:
op_name: str = "ElementWiseSubtract"
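These classes serve as immutable namespaces for QNN op-name strings rather than as instantiable dataclasses: init=False suppresses the generated __init__ and frozen=True guards against mutation. A small self-contained illustration:

from dataclasses import dataclass

@dataclass(init=False, frozen=True)
class OpElementWiseSin:
    op_name: str = "ElementWiseSin"

# Accessed as a class-level constant, exactly as op_sin.py does:
assert OpElementWiseSin.op_name == "ElementWiseSin"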
10 changes: 10 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -271,6 +271,16 @@ def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.cos.default])
def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sin.default])
def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
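With these annotators registered, sin/cos nodes pick up quantization parameters in the standard pt2e flow. A hedged sketch of that flow (the QnnQuantizer import path is an assumption based on the repo layout):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer  # assumed path

class Sin(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

sample_input = (torch.randn(2, 5, 1, 3),)
captured = torch.export.export(Sin(), sample_input).module()

quantizer = QnnQuantizer()
prepared = prepare_pt2e(captured, quantizer)
prepared(*sample_input)             # calibration; annotate_sin tags the node
quantized = convert_pt2e(prepared)  # QDQ graph the HTP backend consumes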
16 changes: 16 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -427,6 +427,14 @@ def forward(self, x):
return topk_values


class Cos(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.cos(x)


class Div(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -889,6 +897,14 @@ def forward(self, x):
return torch.sigmoid(x)


class Sin(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)


class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
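The new modules are thin eager wrappers, so a quick sanity check outside the delegate is trivial (using the Cos and Sin classes defined above):

import torch

x = torch.randn(2, 5, 1, 3)
assert torch.allclose(Cos()(x), torch.cos(x))
assert torch.allclose(Sin()(x), torch.sin(x))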
22 changes: 22 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -143,6 +143,11 @@ def test_qnn_backend_conv_transpose2d(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_cos(self):
module = Cos() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_einsum_outer_product(self):
module = EinsumOuterProduct() # noqa: F405
x = torch.randn(5)
@@ -465,6 +470,11 @@ def test_qnn_backend_sigmoid(self):
sample_input = (torch.randn([1, 3, 3, 3]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_sin(self):
module = Sin() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_select_copy(self):
module = SelectCopy() # noqa: F405
sample_input = (torch.randn([1, 3, 3, 3]),)
@@ -825,6 +835,12 @@ def test_qnn_backend_conv_transpose2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_cos(self):
module = Cos() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_einsum_outer_product(self):
module = EinsumOuterProduct() # noqa: F405
x = torch.randn(5)
@@ -1201,6 +1217,12 @@ def test_qnn_backend_sigmoid(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_sin(self):
module = Sin() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_slice_copy(self):
modules = [SliceCopy(), SliceCopyWithStep()] # noqa: F405
sample_input = (
21 changes: 21 additions & 0 deletions examples/qualcomm/utils.py
@@ -256,6 +256,27 @@ def build_executorch_binary(
custom_pass_config=frozenset(),
qat_training_data=None,
):
"""
A function to generate an ExecuTorch binary for Qualcomm platforms.

Attributes:
model (torch.nn.Module): The model to be converted into an ExecuTorch binary.
inputs (torch.Tensor): Sample input tensors required for model export.
soc_model (QcomChipset): The target Qualcomm System on Chip (SoC) model.
file_name (str): Name for the output binary file (.pte).
dataset (List[torch.Tensor] | Callable): A dataset for quantization calibration.
skip_node_id_set (set, optional): Set of node IDs to be skipped during partition.
skip_node_op_set (set, optional): Set of operation node to be skipped during partition.
quant_dtype (QuantDtype, optional): Data type for quantization.
custom_quantizer (Callable, optional): Custom quantizer.
shared_buffer (bool, optional): Applies zero-copy mechanism to optimize runtime memory allocation.
metadata (dict, optional): An optional dictionary that maps each method name to a constant value in eager mode.
dump_intermediate_outputs (bool, optional): Enables dumping model intermediate outputs.
custom_pass_config (frozenset, optional): Set of custom passes for model processing.

Returns:
None: The function writes the output to a specified .pte file.
"""
if quant_dtype is not None:
captured_model = torch.export.export(model, inputs).module()
if qat_training_data:
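For completeness, a hedged usage sketch of the documented function with one of the new ops; the soc_model string is a placeholder (real callers pass a QcomChipset value), argument order follows the docstring, and quantization options are omitted:

import torch
from examples.qualcomm.utils import build_executorch_binary  # module patched in this PR

class Sin(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

sample_input = (torch.randn(2, 5, 1, 3),)

build_executorch_binary(
    Sin(),
    sample_input,
    soc_model="SM8550",       # placeholder; actual code uses a QcomChipset enum value
    file_name="sin_qnn",
    dataset=[sample_input],   # calibration data used when quantizing
)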