Skip to content

Commit 82faf9b

Browse files
authored
Qualcomm AI Engine Direct - XR model mld_a enablement (#9129)
### Summary - make index op builder more general - small refactor on layout_transform - support new pattern of upsample2d ### Test plan ```bash python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator -s $SERIAL_NO -m SM8650 -b build-android ```
1 parent 37fa261 commit 82faf9b

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from executorch.exir.pass_base import ExportPass, PassResult
2020
from executorch.exir.sym_util import eval_shape
2121

22-
from .utils import dq_ops, q_ops
23-
2422

2523
class LayoutTransform(ExportPass):
2624
"""
@@ -91,8 +89,6 @@ class LayoutTransform(ExportPass):
9189
exir_ops.edge.aten.topk.default,
9290
exir_ops.edge.aten._to_copy.default,
9391
exir_ops.edge.aten.where.self,
94-
*q_ops,
95-
*dq_ops,
9692
_operator.getitem,
9793
}
9894

@@ -117,7 +113,6 @@ def __init__(
117113
super(LayoutTransform, self).__init__()
118114
self.edge_program = edge_program
119115
self.insert_permute = insert_permute
120-
self.qdq_opset = {*q_ops, *dq_ops}
121116
self.transformed_tag = QCOM_AXIS_ORDER
122117

123118
def mark_as_transformed(self, node: torch.fx.Node) -> None:

backends/qualcomm/builders/op_index.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def define_node(
3838
nodes_to_wrappers,
3939
)
4040

41-
if len(node.args[1]) > 1:
42-
# TODO consider to implement it in a recursive way.
43-
raise NotImplementedError("Not support tuple of tensor.")
44-
45-
indices_node = node.args[1][0]
41+
# e.g. x[:, index]:
42+
# > node.args[1] = [None, indices]
43+
# > axis = 1
44+
axis = len(node.args[1]) - 1
45+
indices_node = node.args[1][axis]
4646
indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
4747
assert indices_tensor.size(0) != 0, "Not support empty indices list"
4848

@@ -78,7 +78,7 @@ def define_node(
7878
gather_op.AddScalarParam(
7979
OpGather.param_axis,
8080
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
81-
{QCOM_DATA: np.int32(0)},
81+
{QCOM_DATA: np.int32(axis)},
8282
)
8383

8484
return gather_op

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,13 +746,19 @@ def forward(self, x):
746746

747747

748748
class Index(torch.nn.Module):
749-
def __init__(self):
749+
def __init__(self, axis):
750750
super().__init__()
751751
self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
752752
self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
753+
self.axis = axis
754+
self.dispatcher = {
755+
0: lambda x: x[self.idx0] + x[self.idx1],
756+
1: lambda x: x[:, self.idx0] + x[:, self.idx1],
757+
2: lambda x: x[:, :, self.idx0] + x[:, :, self.idx1],
758+
}
753759

754760
def forward(self, x):
755-
return x[self.idx0] + x[self.idx1]
761+
return self.dispatcher[self.axis](x)
756762

757763

758764
class IndexPut(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,11 @@ def test_qnn_backend_hardtanh(self):
469469
self.lower_module_and_test_output(module, sample_input)
470470

471471
def test_qnn_backend_index(self):
472-
module = Index() # noqa: F405
472+
modules = [Index(0), Index(1), Index(2)] # noqa: F405
473473
sample_input = (torch.randn([8, 172, 64]),)
474-
self.lower_module_and_test_output(module, sample_input)
474+
for i, module in enumerate(modules):
475+
with self.subTest(i=i):
476+
self.lower_module_and_test_output(module, sample_input)
475477

476478
def test_qnn_backend_index_put(self):
477479
module = IndexPut() # noqa: F405
@@ -1457,10 +1459,12 @@ def test_qnn_backend_hardtanh(self):
14571459
self.lower_module_and_test_output(module, sample_input)
14581460

14591461
def test_qnn_backend_index(self):
1460-
module = Index() # noqa: F405
1462+
modules = [Index(0), Index(1), Index(2)] # noqa: F405
14611463
sample_input = (torch.randn([8, 172, 64]),)
1462-
module = self.get_qdq_module(module, sample_input)
1463-
self.lower_module_and_test_output(module, sample_input)
1464+
for i, module in enumerate(modules):
1465+
with self.subTest(i=i):
1466+
module = self.get_qdq_module(module, sample_input)
1467+
self.lower_module_and_test_output(module, sample_input)
14641468

14651469
def test_qnn_backend_index_put(self):
14661470
module = IndexPut() # noqa: F405

0 commit comments

Comments
 (0)