
Commit 1c6dbb6

haowhsu-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - support Conv2dTranspose (#5461)
Summary:
- Conv2dTranspose op enablement
- test cases

Pull Request resolved: #5461
Reviewed By: kirklandsign
Differential Revision: D63568634
Pulled By: cccclai
fbshipit-source-id: 8add7116b6a40db1654d0edd50483d43ade31ff2
1 parent e19677c commit 1c6dbb6
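
For context: with this change, aten.conv_transpose2d nodes are lowered to QNN's TransposeConv2d op instead of being rejected, and the op's output_padding argument is forwarded as a QNN tensor parameter rather than causing a bail-out. output_padding matters because PyTorch computes the transposed-conv output size as H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1, so shapes needing a nonzero output_padding were previously unrepresentable. A quick eager-mode sanity check of that formula (plain PyTorch, nothing backend-specific; the module mirrors the new ConvTranspose2dSingle test model):

import torch

# H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
m = torch.nn.ConvTranspose2d(in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 1, 3, 3)
print(m(x).shape)  # torch.Size([1, 3, 5, 5]): (3 - 1)*2 - 2*1 + 1*(3 - 1) + 0 + 1 = 5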

File tree

5 files changed (+149, -58)

backends/qualcomm/builders/op_conv2d.py

Lines changed: 61 additions & 57 deletions
@@ -18,6 +18,7 @@
     OpDepthWiseConv2d,
     OpExpandDims,
     OpReshape,
+    OpTransposeConv2d,
     QNN_OP_PACKAGE_NAME_QTI_AISW,
 )
 from .utils import get_parameter
@@ -42,6 +43,9 @@ def _add_conv_op_parameter(
         padding_shape,
         dilation,
         dilation_shape,
+        output_padding=None,
+        output_padding_shape=None,
+        transpose_conv=False,
         groups=None,
     ) -> PyQnnWrapper.PyQnnOpWrapper:
         """
@@ -68,14 +72,26 @@ def _add_conv_op_parameter(
             ),
             True,
         )
-        conv_op.AddTensorParam(
-            OP.param_dilation,
-            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            len(dilation_shape),
-            dilation_shape,
-            np.array(dilation, dtype=np.uint32),
-            True,
-        )
+
+        if transpose_conv:
+            conv_op.AddTensorParam(
+                OP.param_output_padding,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                len(output_padding_shape),
+                output_padding_shape,
+                np.array(output_padding, dtype=np.uint32),
+                True,
+            )
+        else:
+            conv_op.AddTensorParam(
+                OP.param_dilation,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                len(dilation_shape),
+                dilation_shape,
+                np.array(dilation, dtype=np.uint32),
+                True,
+            )
+
         if groups is not None:
             conv_op.AddScalarParam(
                 OP.param_group,
@@ -94,6 +110,11 @@ def _define_conv1d(
         Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore,
         we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output.
         """
+        transpose_conv = cast(bool, node.args[6])
+        if transpose_conv:
+            print("ConvTranspose1d is not yet supported")
+            return
+
         op_wrapper_list = []  # op_wrapper to return
         unsqueeze_input_node = node.args[0]
         input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
@@ -239,9 +260,9 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
-
         if get_parameter(node.args[1], self.edge_program).dim() == 3:
             return self._define_conv1d(node, nodes_to_wrappers)
+
         input_node = node.args[0]
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
@@ -254,8 +275,9 @@ def define_node(
 
         filter_node = node.args[1]
         filter_tensor = get_parameter(filter_node, self.edge_program)
-        # weight of pytorch OIHW, yet QNN is HWIO
-        filter_axis_order = (2, 3, 1, 0)
+        # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
+        is_transpose_conv = cast(bool, node.args[6])
+        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
@@ -291,6 +313,7 @@ def define_node(
         stride = cast(List[int], node.args[3])
         padding = cast(List[int], node.args[4])
         dilation = cast(List[int], node.args[5])
+        output_padding = cast(List[int], node.args[7])
 
         groups = cast(int, node.args[8])
         # Qnn filter tensor is (H, W, Cin, Cout)
@@ -308,57 +331,38 @@ def define_node(
         if len(padding) == 1:
             padding = padding + padding
 
-        # args[6] = transposed
-        if cast(bool, node.args[6]):
-            print("Currently, No support for transposed convolution")
-            return
-
-        # args[7] = output padding
-        if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
-            print("QNN does not support output padding")
-            return
-
         stride_shape = [len(stride)]
         padding_shape = [2, 2]
         dilation_shape = [len(dilation)]
+        output_padding_shape = [len(output_padding)]
 
         if is_depthwise_conv:
-            conv_op = PyQnnWrapper.PyQnnOpWrapper(
-                node.name,
-                QNN_OP_PACKAGE_NAME_QTI_AISW,
-                OpDepthWiseConv2d.op_name,
-            )
-            conv_op = self._add_conv_op_parameter(
-                OpDepthWiseConv2d,
-                conv_op,
-                conv_input_tensors,
-                conv_output_tensors,
-                stride,
-                stride_shape,
-                padding,
-                padding_shape,
-                dilation,
-                dilation_shape,
-            )
-
+            op_class = OpDepthWiseConv2d
+        elif is_transpose_conv:
+            op_class = OpTransposeConv2d
         else:
-            conv_op = PyQnnWrapper.PyQnnOpWrapper(
-                node.name,
-                QNN_OP_PACKAGE_NAME_QTI_AISW,
-                OpConv2d.op_name,
-            )
-            conv_op = self._add_conv_op_parameter(
-                OpConv2d,
-                conv_op,
-                conv_input_tensors,
-                conv_output_tensors,
-                stride,
-                stride_shape,
-                padding,
-                padding_shape,
-                dilation,
-                dilation_shape,
-                groups,
-            )
+            op_class = OpConv2d
+
+        conv_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            op_class.op_name,
+        )
+        conv_op = self._add_conv_op_parameter(
+            op_class,
+            conv_op,
+            conv_input_tensors,
+            conv_output_tensors,
+            stride,
+            stride_shape,
+            padding,
+            padding_shape,
+            dilation,
+            dilation_shape,
+            output_padding,
+            output_padding_shape,
+            is_transpose_conv,
+            None if is_depthwise_conv else groups,
+        )
 
         return conv_op
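
The axis-order switch above is the subtle part of this diff: PyTorch Conv2d weights are laid out OIHW while ConvTranspose2d weights are IOHW, and both must land in QNN's HWIO layout, hence the two permutations. A standalone sketch of the same reordering (toy shapes, not the builder's tensor wrappers):

import torch

# Conv2d weight (O=8, I=4, H=3, W=3) -> HWIO via (2, 3, 1, 0)
w_conv = torch.randn(8, 4, 3, 3)
print(w_conv.permute(2, 3, 1, 0).shape)    # torch.Size([3, 3, 4, 8])

# ConvTranspose2d weight (I=4, O=8, H=3, W=3) -> HWIO via (2, 3, 0, 1)
w_deconv = torch.randn(4, 8, 3, 3)
print(w_deconv.permute(2, 3, 0, 1).shape)  # torch.Size([3, 3, 4, 8])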

backends/qualcomm/builders/qnn_constants.py

Lines changed: 9 additions & 0 deletions
@@ -356,3 +356,12 @@ class OpTile:
 class OpTranspose:
     op_name: str = "Transpose"
     param_perm: str = "perm"
+
+
+@dataclass(init=False, frozen=True)
+class OpTransposeConv2d:
+    op_name: str = "TransposeConv2d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_output_padding: str = "output_padding"
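
These frozen, init=False dataclasses serve as plain namespaces: nothing instantiates them, the builders just read the class attributes for QNN op and parameter names (as op_conv2d.py does with op_class.op_name and OP.param_output_padding above). A minimal illustration of the pattern, with a hypothetical consumer function:

from dataclasses import dataclass

@dataclass(init=False, frozen=True)
class OpTransposeConv2d:
    op_name: str = "TransposeConv2d"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_group: str = "group"
    param_output_padding: str = "output_padding"

# Hypothetical consumer: attributes are read straight off the class.
def describe(op_class) -> str:
    return f"{op_class.op_name}({op_class.param_stride}, {op_class.param_output_padding})"

print(describe(OpTransposeConv2d))  # TransposeConv2d(stride, output_padding)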

backends/qualcomm/quantizer/utils.py

Lines changed: 7 additions & 1 deletion
@@ -941,7 +941,13 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
     node.meta["source_fn_stack"] = [(node, torch.bmm)]
 
 
-@register_annotator([torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default])
+@register_annotator(
+    [
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+    ]
+)
 def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
backends/qualcomm/tests/models.py

Lines changed: 40 additions & 0 deletions
@@ -361,6 +361,46 @@ def forward(self, x):
         return self.conv(x)
 
 
+class ConvTranspose2dSingle(torch.nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.conv_transpose = torch.nn.ConvTranspose2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv_transpose(x)
+
+
+class Conv2dDownUpSample(torch.nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=16,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=bias,
+        )
+        self.conv_transpose = torch.nn.ConvTranspose2d(
+            in_channels=16,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv_transpose(self.conv(x))
+
+
 class Conv2dSumReduceDim(torch.nn.Module):
     def __init__(self):
         super().__init__()
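
One thing to note about Conv2dDownUpSample: with kernel 3, stride 2, padding 1 and the default output_padding=0, the down/up round trip is not shape-preserving, so the test exercises a non-trivial output shape. A quick eager-mode check (plain PyTorch, outside the delegate):

import torch

down = torch.nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
up = torch.nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1)

x = torch.randn(1, 16, 224, 224)
y = down(x)  # floor((224 + 2*1 - 3)/2) + 1 = 112
z = up(y)    # (112 - 1)*2 - 2*1 + (3 - 1) + 0 + 1 = 223
print(y.shape, z.shape)  # [1, 16, 112, 112] [1, 16, 223, 223]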

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 32 additions & 0 deletions
@@ -130,6 +130,16 @@ def test_qnn_backend_conv2d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv_transpose2d(self):
+        modules = [
+            ConvTranspose2dSingle(),  # noqa: F405
+            ConvTranspose2dSingle(bias=False),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_element_wise_add(self):
         test_comb = [
             {
@@ -521,6 +531,11 @@ def test_qnn_backend_conv2d_cat(self):
         sample_input = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_down_up_sample(self):
+        module = Conv2dDownUpSample()  # noqa: F405
+        sample_input = (torch.randn(1, 16, 224, 224),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_max_pool2d(self):
         module = Conv2dMaxPool2d()  # noqa: F405
         sample_input = (torch.rand(1, 2, 14, 14),)
@@ -713,6 +728,17 @@ def test_qnn_backend_conv2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv_transpose2d(self):
+        modules = [
+            ConvTranspose2dSingle(),  # noqa: F405
+            ConvTranspose2dSingle(bias=False),  # noqa: F405
+        ]  # noqa: F405
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_element_wise_add(self):
         test_comb = [
             {
@@ -1157,6 +1183,12 @@ def test_qnn_backend_conv2d_cat(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_down_up_sample(self):
+        module = Conv2dDownUpSample()  # noqa: F405
+        sample_input = (torch.randn(1, 16, 224, 224),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_max_pool2d(self):
         module = Conv2dMaxPool2d()  # noqa: F405
         sample_input = (torch.rand(1, 2, 14, 14),)
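
To run just the new cases, standard unittest -k filtering should work, e.g. (this invocation is an assumption; the QNN delegate test harness also expects its usual device/build flags, which this diff does not change):

python -m unittest backends.qualcomm.tests.test_qnn_delegate -k conv_transpose2d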
