
Commit 00b0ce5

Add support for quantized transposed conv
Differential Revision: D68939306
Pull Request resolved: #8090
1 parent ae61caa commit 00b0ce5

File tree

3 files changed, +53 -33 lines


backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 2 additions & 1 deletion
@@ -249,8 +249,9 @@ class XNNPACKQuantizer(Quantizer):
     STATIC_OPS = [
         "linear_relu",
         "linear",
-        "conv_relu",
         "conv",
+        "conv_transpose",
+        "conv_relu",
         "conv_transpose_relu",
         "adaptive_avg_pool2d",
         # TODO: move this to BoltNNQuantizer?
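
For context, a minimal sketch of how the newly listed "conv_transpose" static op might be exercised end to end. XNNPACKQuantizer and get_symmetric_quantization_config come from this repo; the capture and PT2E quantize calls below are assumptions about the surrounding flow (API names vary across PyTorch/ExecuTorch versions), not part of this commit:

    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    class Deconv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # transposed conv, now matched by the "conv_transpose" annotator
            self.deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3)

        def forward(self, x):
            return self.deconv(x)


    example_inputs = (torch.randn(1, 4, 16, 16),)
    # export_for_training is one way to get the graph PT2E expects
    gm = torch.export.export_for_training(Deconv(), example_inputs).module()

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    gm = prepare_pt2e(gm, quantizer)
    gm(*example_inputs)  # calibration pass
    gm = convert_pt2e(gm)  # transposed conv now carries q/dq nodes

After convert_pt2e, the transposed conv is quantized rather than left in fp32, which is what the test changes below start asserting.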

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 47 additions & 22 deletions
@@ -91,6 +91,16 @@ class OperatorConfig(NamedTuple):
     operators: list[OperatorPatternType]


+def is_relu_node(node: Node) -> bool:
+    """
+    Check if a given node is a relu node
+    """
+    return node.op == "call_function" and node.target in [
+        torch.ops.aten.relu.default,
+        torch.ops.aten.relu_.default,
+    ]
+
+
 def _is_annotated(nodes: list[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
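
A small illustration of the new helper's behavior on a hand-built FX graph (the graph is hypothetical, constructed only to show that both the functional and in-place relu targets are matched):

    import torch
    from torch.fx import Graph

    # Build a trivial graph containing both relu variants by hand.
    g = Graph()
    x = g.placeholder("x")
    a = g.call_function(torch.ops.aten.relu.default, (x,))
    b = g.call_function(torch.ops.aten.relu_.default, (a,))
    g.output(b)

    relu_targets = [torch.ops.aten.relu.default, torch.ops.aten.relu_.default]
    for n in g.nodes:
        if n.op == "call_function":
            print(n.target, n.target in relu_targets)  # both print True

Factoring this predicate out lets the four call sites below drop the repeated two-element target list.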
@@ -231,10 +241,7 @@ def _annotate_linear_relu(
     weight_qspec = get_weight_qspec(quantization_config)
     bias_qspec = get_bias_qspec(quantization_config)
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_linear_node = node.args[0]
@@ -285,21 +292,28 @@ def _annotate_linear_relu(
     return annotated_partitions


-@register_annotator("conv")
-def _annotate_conv(
+def _do_annotate_conv(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
+    is_conv_transpose: bool = False,
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
+    is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+
     for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if not is_conv_node(n):
             continue
         conv_node = n

+        # This is hacky!
+        # We do not want to annotate conv node independently if there is a conv + relu pattern
+        # So we skip if the conv node is consumed by a single relu node
+        if len(conv_node.users) == 1:
+            user = list(conv_node.users.keys())[0]
+            if is_relu_node(user):
+                continue
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
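
The skip above relies on FX's node.users bookkeeping. A hand-built sketch (hypothetical graph, not from this commit) of the pattern being skipped:

    import torch
    from torch.fx import Graph

    # conv2d feeding a single relu: the conv has exactly one user.
    g = Graph()
    x = g.placeholder("x")
    w = g.placeholder("w")
    conv = g.call_function(torch.ops.aten.conv2d.default, (x, w))
    act = g.call_function(torch.ops.aten.relu.default, (conv,))
    g.output(act)

    # Mirrors the condition in _do_annotate_conv.
    if len(conv.users) == 1:
        (user,) = conv.users  # node.users is a dict keyed by user nodes
        print(user.target)  # aten.relu.default -> standalone annotation skipped

This appears to be why STATIC_OPS was reordered in xnnpack_quantizer.py: "conv" and "conv_transpose" now run ahead of the fused "conv_relu" pass, so without this skip the standalone annotator would claim convs that belong to conv + relu partitions.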
@@ -341,10 +355,7 @@ def _do_annotate_conv_relu(
 ):
     annotated_partitions = []
     for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(n):
             continue
         relu_node = n
         maybe_conv_node = n.args[0]
@@ -393,6 +404,26 @@ def _do_annotate_conv_relu(
     return annotated_partitions


+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose")
+def _annotate_transpose_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(gm, quantization_config, filter_fn, is_conv_transpose=True)
+
+
 @register_annotator("conv_relu")
 def _annotate_conv_relu(
     gm: torch.fx.GraphModule,
@@ -744,10 +775,7 @@ def _annotate_add_relu(  # noqa: C901
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_add = node.args[0]
@@ -872,10 +900,7 @@ def _annotate_mul_relu(  # noqa: C901
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_mul = node.args[0]

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 4 additions & 10 deletions
@@ -243,15 +243,14 @@ def test_qs8_conv2d_test(self) -> None:
             self._test(
                 Conv2d(bias=has_bias, transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose,  # XNNPackQuantizer does not quantize this pattern yet
             )

     def test_qs8_conv2d_per_channel(self) -> None:
         for transpose in (True, False):
             self._test(
                 Conv2d(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                check_quantized=not transpose,  # XNNPackQuantizer does not quantize this pattern yet
+                delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
             )

     def test_fp32_conv2d_seq(self) -> None:
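
The "per input channel ... groups > 1" comment recurring in these tests traces to the weight layout of transposed convolutions. A quick shape check (standard PyTorch behavior, shown only for illustration):

    import torch

    # Conv2d weights:          [out_channels, in_channels // groups, kH, kW]
    # ConvTranspose2d weights: [in_channels, out_channels // groups, kH, kW]
    conv = torch.nn.Conv2d(4, 8, kernel_size=3, groups=2)
    deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, groups=2)
    print(conv.weight.shape)    # torch.Size([8, 2, 3, 3])
    print(deconv.weight.shape)  # torch.Size([4, 4, 3, 3])

Per-channel scales for a transposed conv therefore sit on dim 1 rather than dim 0, and with groups > 1 that dimension no longer spans all output channels, which is plausibly the case XNNPACK declines to delegate.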
@@ -264,7 +263,6 @@ def test_qs8_conv2d_seq(self) -> None:
                 Conv2dSeq(transpose=transpose),
                 conv_count=2,
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose,  # XNNPackQuantizer does not quantize this pattern yet
             )

     def test_fp32_conv2d_single_int_params(self):
@@ -282,7 +280,6 @@ def test_fp32_conv2d_depthwise(self):
         # - Groups must equal In Channels
         # - Out Channels must be a positive multiple of In Channels
         for transpose in (True, False):
-
             self._test(
                 Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose)
             )
@@ -292,7 +289,6 @@ def test_qs8_conv2d_depthwise(self):
             self._test(
                 Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose,  # XNNPackQuantizer does not quantize this pattern yet
             )

     def test_fp32_conv2d_bn(self):
@@ -384,7 +380,6 @@ def test_qs8_conv2d_bn(self):
                 Conv2dBatchNorm(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
                 conv_count=2,
-                check_quantized=not transpose,  # XNNPackQuantizer does not quantize this pattern yet
             )

     def test_qs8_conv2d_relu(self):
@@ -415,7 +410,7 @@ def get_inputs(self):
             self._test(
                 ConvReLU(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                delegated=not transpose,
+                delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
             )

     def test_qs8_conv2d_dw_relu(self):
@@ -467,9 +462,8 @@ def get_inputs(self):
                 quant_config=get_symmetric_quantization_config(
                     is_per_channel=per_channel_quant
                 ),
-                # xnnpack only supports per output channel quantization for transposed convolutions
-                # XNNPackQuantizer quantizes per input channel currently
-                delegated=not transpose or not per_channel_quant,
+                # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
+                delegated=not (transpose and per_channel_quant),
             )

     def test_qs8_conv2d_relu_seq(self):
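
One note on the flag rewrite in this last hunk: the new expression is logically identical to the old one (De Morgan), so only the comment's substance changed. A two-loop exhaustive check:

    # Verify the old and new delegated flags agree for all boolean inputs.
    for transpose in (False, True):
        for per_channel_quant in (False, True):
            old = not transpose or not per_channel_quant
            new = not (transpose and per_channel_quant)
            assert old == new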
