Commit 9289b3f

Qualcomm AI Engine Direct - support dim_order (#7189)
Summary:
- remove the _to_dim op and let the backend take care of tensor layout
- add test cases
1 parent 78b60df commit 9289b3f
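
Background (not part of the commit): in ExecuTorch, dim order expresses tensor layout explicitly as a permutation of logical dimensions, so a contiguous NCHW tensor has dim order (0, 1, 2, 3) and a channels_last tensor has (0, 2, 3, 1). A minimal sketch of that correspondence using plain PyTorch strides; the dim_order_of helper below is hypothetical, not an ExecuTorch API:

import torch

# Hypothetical helper (illustrative only): recover a dim order from strides by
# listing dimensions from outermost to innermost physical placement.
def dim_order_of(t: torch.Tensor):
    return tuple(sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True))

x = torch.randn(1, 3, 4, 4)                                    # contiguous NCHW
print(dim_order_of(x))                                         # (0, 1, 2, 3)
print(dim_order_of(x.to(memory_format=torch.channels_last)))   # (0, 2, 3, 1)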

4 files changed: +36 -3 lines changed

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 12 additions & 1 deletion
@@ -11,7 +11,7 @@

 class RemoveRedundancy(ExportPass):
     """
-    Trim the 'identity' operators to reduce the unnecessary copy overhead.
+    Trim certain operators to reduce unnecessary overhead.
     """

     redundant_ops = {
@@ -21,6 +21,10 @@ class RemoveRedundancy(ExportPass):
         torch.ops.aten.alias.default,
         exir_ops.edge.aten.alias.default,
         exir_ops.edge.aten.lift_fresh_copy.default,
+        # remove this target if '_skip_dim_order' is set to False
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
+        exir_ops.edge.aten._to_copy.default,
     }

     def __init__(self):
@@ -31,6 +35,13 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if n.target not in self.redundant_ops:
                 continue

+            # do not remove cast operator
+            if (
+                n.target == exir_ops.edge.aten._to_copy.default
+                and "memory_format" not in n.kwargs
+            ):
+                continue
+
             to_be_remove = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
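
The guard added above distinguishes layout-changing copies from dtype casts: an aten._to_copy node that carries a memory_format kwarg only changes layout, which the backend can absorb, while a _to_copy without memory_format is a cast and must stay in the graph. A minimal sketch of that decision rule, using a hypothetical FakeNode stand-in rather than real torch.fx nodes:

from dataclasses import dataclass, field

# FakeNode is a hypothetical stand-in for torch.fx.Node, used only to
# illustrate the keep-the-cast rule from the pass above.
@dataclass
class FakeNode:
    target: str
    kwargs: dict = field(default_factory=dict)

TO_COPY = "aten._to_copy.default"

def is_removable(node: FakeNode) -> bool:
    # _to_copy without a memory_format kwarg is a dtype cast -> keep it
    if node.target == TO_COPY and "memory_format" not in node.kwargs:
        return False
    return True

assert is_removable(FakeNode(TO_COPY, {"memory_format": "channels_last"}))
assert not is_removable(FakeNode(TO_COPY, {"dtype": "float16"}))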

backends/qualcomm/tests/models.py

Lines changed: 3 additions & 1 deletion
@@ -325,7 +325,7 @@ def forward(self, x):


 class Conv2dSequential(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, bias=True, channel_last=False):
         super().__init__()
         self.first = torch.nn.Conv2d(
             in_channels=1,
@@ -341,8 +341,10 @@ def __init__(self, bias=True):
             padding=1,
             bias=bias,
         )
+        self.channel_last = channel_last

     def forward(self, x):
+        x = x.to(memory_format=torch.channels_last) if self.channel_last else x
         return self.second(self.first(x))

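A quick usage sketch of the updated test module (the import path is assumed from the file location; everything else mirrors the diff above):

import torch
from executorch.backends.qualcomm.tests.models import Conv2dSequential  # assumed path

module = Conv2dSequential(bias=False, channel_last=True)
x = torch.randn(1, 1, 3, 3)   # contiguous NCHW sample, matching the tests
y = module(x)                 # forward() re-lays x out as channels_last first
print(y.shape)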

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 21 additions & 0 deletions
@@ -133,6 +133,16 @@ def test_qnn_backend_conv2d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_conv2d_channel_last(self):
+        modules = [
+            Conv2dSequential(channel_last=True),  # noqa: F405
+            Conv2dSequential(bias=False, channel_last=True),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv_transpose2d(self):
         modules = [
             ConvTranspose2dSingle(),  # noqa: F405
@@ -824,6 +834,17 @@ def test_qnn_backend_conv2d(self):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_conv2d_channel_last(self):
+        modules = [
+            Conv2dSequential(channel_last=True),  # noqa: F405
+            Conv2dSequential(bias=False, channel_last=True),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv_transpose2d(self):
         modules = [
             ConvTranspose2dSingle(),  # noqa: F405
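
The new tests reuse the existing subTest pattern so each module variant reports separately. A self-contained illustration of that pattern, with a placeholder value check standing in for the real lower_module_and_test_output / get_qdq_module helpers:

import unittest
import torch

class SubTestPatternExample(unittest.TestCase):
    # Placeholder test: mirrors the loop structure above, but only checks that
    # a channels_last conversion preserves values instead of lowering to QNN.
    def test_channels_last_inputs(self):
        inputs = [torch.randn([1, 1, 3, 3]) for _ in range(2)]
        for i, x in enumerate(inputs):
            with self.subTest(i=i):
                x_cl = x.to(memory_format=torch.channels_last)
                self.assertTrue(torch.allclose(x, x_cl))

if __name__ == "__main__":
    unittest.main()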

backends/qualcomm/utils/utils.py

Lines changed: 0 additions & 1 deletion
@@ -166,7 +166,6 @@ def qnn_capture_config():
 def qnn_edge_config() -> exir.EdgeCompileConfig:
     return exir.EdgeCompileConfig(
         _check_ir_validity=False,
-        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
     )

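With _skip_dim_order no longer forced to True, _to_dim_order_copy nodes can reach the edge graph and are then trimmed by RemoveRedundancy. A hedged sketch of how qnn_edge_config() might be consumed; torch.export and to_edge are standard APIs, but the overall flow here is an assumption rather than the backend's documented entry point:

import torch
from executorch.exir import to_edge
from executorch.backends.qualcomm.utils.utils import qnn_edge_config      # assumed path
from executorch.backends.qualcomm.tests.models import Conv2dSequential    # assumed path

module = Conv2dSequential(channel_last=True).eval()
sample_input = (torch.randn([1, 1, 3, 3]),)

# Export, then build the edge-dialect program with the QNN edge config; dim
# order ops are no longer skipped at this stage.
exported = torch.export.export(module, sample_input)
edge = to_edge(exported, compile_config=qnn_edge_config())
print(edge.exported_program().graph)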
