Commit 821a2fe

add argmin, add_ ops
Differential Revision: D68811662
Pull Request resolved: #8040
1 parent a3455d9 commit 821a2fe

7 files changed, 120 insertions(+), 0 deletions(-)


backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.ceil.default,
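Marking aten.argmin.default as layout agnostic lets the pass leave NHWC activations untouched and only remap the reduction axis, instead of wrapping the op in layout conversions. A minimal eager-mode sketch of the property this relies on; the NCHW-to-NHWC permutation and axis order here are illustrative, not taken from the pass:

import torch

x = torch.randn(2, 3, 4, 5)               # NCHW activation
axis_order = (0, 2, 3, 1)                  # NCHW -> NHWC permutation
dim = 1                                    # reduce over channels in NCHW
nhwc_dim = axis_order.index(dim)           # same axis after permutation -> 3

nchw_result = torch.argmin(x, dim=dim, keepdim=True).permute(axis_order)
nhwc_result = torch.argmin(x.permute(axis_order), dim=nhwc_dim, keepdim=True)
assert torch.equal(nchw_result, nhwc_result)   # argmin commutes with the layout permute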

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_arange,
+    op_argmin,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -82,6 +83,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_arange,
+    op_argmin,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
backends/qualcomm/builders/op_argmin.py (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Argmin(NodeVisitor):
+    target = ["aten.argmin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        argmin_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmin_input_tensors = [argmin_inp_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node).to(torch.int32)
+        # arg output is index, do not quantize it.
+        node.meta.pop("quant_attrs", None)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmin_output_tensors = [output_tensor_wrapper]
+
+        dim = cast(int, node.args[1])
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+
+        argmin_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpArgmin.op_name,
+        )
+        argmin_op.AddInputTensors(argmin_input_tensors)
+        argmin_op.AddOutputTensors(argmin_output_tensors)
+
+        argmin_op.AddScalarParam(
+            OpArgmin.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            argmin_op.AddScalarParam(
+                OpArgmin.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return argmin_op
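The axis handling in define_node first normalizes a negative dim, then, when the layout pass has permuted the activation (QCOM_AXIS_ORDER present in node.meta), remaps the original axis to its position in the permuted order. A small standalone sketch of that arithmetic; the helper name and the example axis order are illustrative only:

def remap_argmin_axis(dim: int, rank: int, axis_order=None) -> int:
    # normalize negative dims, mirroring define_node above
    if dim < 0:
        dim = dim % rank
    # remap the original axis to its index in the permuted layout, if any
    if axis_order is not None:
        dim = axis_order.index(dim)
    return dim

# channel axis 1 of an NCHW tensor becomes axis 3 after an NCHW -> NHWC transform
assert remap_argmin_axis(1, 4, axis_order=(0, 2, 3, 1)) == 3
# a negative dim is first normalized against the tensor rank
assert remap_argmin_axis(-1, 4) == 3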

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions
@@ -301,6 +301,13 @@ class OpReduceMean:
     param_keep_dims: str = "keep_dims"
 
 
+@dataclass(init=False, frozen=True)
+class OpArgmin:
+    op_name: str = "Argmin"
+    param_axis: str = "axis"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpReduceSum:
     op_name: str = "ReduceSum"

backends/qualcomm/quantizer/annotators.py

Lines changed: 7 additions & 0 deletions
@@ -167,6 +167,13 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.argmin.default])
+def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
+    if _is_annotated([node]):
+        return
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
 def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
@@ -275,6 +275,18 @@ def forward(self, x):
         return x
 
 
+class Conv2dArgmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            3, 16, 7, bias=True, stride=2, padding=3, dilation=1
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.argmin(x, dim=0, keepdim=True)
+
+
 class Conv2dAvgPool2d(torch.nn.Module):
     def __init__(self):
         super().__init__()
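For reference, a quick eager-mode check of what the new test model produces on the sample input used by the tests below; the shapes follow from the conv hyperparameters, and this snippet is illustrative rather than part of the commit:

import torch

conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
x = torch.randn(16, 3, 4, 4)                      # sample input from the tests
y = torch.argmin(conv(x), dim=0, keepdim=True)    # reduce over the batch dim
print(y.shape, y.dtype)                           # torch.Size([1, 16, 2, 2]) torch.int64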

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
@@ -112,6 +112,11 @@ def test_qnn_backend_arange(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 4, 4),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_avg_pool2d(self):
         module = AvgPoolModule()  # noqa: F405
         sample_input = (torch.randn(1, 3, 2, 2),)
@@ -939,6 +944,12 @@ def test_qnn_backend_arange(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 4, 4),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_avg_pool2d(self):
         module = AvgPoolModule()  # noqa: F405
         sample_input = (torch.randn(1, 3, 2, 2),)
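One detail worth noting when the tests compare delegate output against the eager reference: eager torch.argmin returns int64 indices, while the builder above declares an int32 output tensor for QNN. A small sketch of that dtype relationship; how lower_module_and_test_output actually reconciles dtypes is not shown in this diff:

import torch

ref = torch.argmin(torch.randn(16, 16, 2, 2), dim=0, keepdim=True)
assert ref.dtype == torch.int64                          # eager reference indices
delegate_like = ref.to(torch.int32)                      # dtype the QNN builder declares
assert torch.equal(ref, delegate_like.to(torch.int64))   # index values are identical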
