
Commit c630569

salilsdesai authored and facebook-github-bot committed
Add pass to properly check if q and dq nodes are implicit (#49)
Summary:
Pull Request resolved: #49

This diff adds a pass which is used to tag "implicit" q/dq nodes, which should be ignored during preprocessing. A q or dq node is deemed to be "implicit" if any of the following hold:

a) All of its inputs are constants (get_attr nodes), since (de)quantizing constants is done outside of executing the graph.

b) It is the q or dq surrounding a "supported" group of nodes, ordered as dq -> [supported group] -> q. A "supported" group is comprised of one of the following: (i) a single supported op, from SUPPORTED_QUANT_OPS_SET, (ii) a single supported module, from SUPPORTED_QUANT_MODULES_SET, or (iii) a chain of nodes matching a supported chain from SUPPORTED_QUANT_CHAINS. q/dq nodes which match this condition should be ignored during preprocessing because they are only used as signaling for the q params of node inputs.

c) It is a dq followed by aten.linear.default and then an output node. This is because aten.linear.default is a special op corresponding to dqlinear, which doesn't necessarily have a q after it.

Reviewed By: digantdesai, mcr229

Differential Revision: D47900500

fbshipit-source-id: b1a34901f41e4a0d528c26073c1f3afdc7820b4b
1 parent b66faf2 commit c630569
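The new tag_implicit_q_dq_pass.py itself is not among the diffs shown below, so as a rough illustration of rule (a) only, here is a minimal torch.fx sketch of tagging a q/dq node whose inputs are all constants. The IMPLICIT_Q_DQ_TAG meta key and the helper name are hypothetical and not taken from this commit.

import torch.fx

# Hypothetical tag key, for illustration only; the real pass may use a different key.
IMPLICIT_Q_DQ_TAG = "xnnpack_implicit_q_dq"


def tag_if_constant_q_dq(node: torch.fx.Node) -> None:
    """Rule (a): a q/dq node whose inputs are all get_attr (constant) nodes is
    implicit, since (de)quantizing constants happens outside graph execution."""
    # Matches both quantize_per_tensor and dequantize_per_tensor targets.
    is_q_or_dq = node.op == "call_function" and "quantize_per_tensor" in str(node.target)
    if not is_q_or_dq:
        return
    inputs = node.all_input_nodes
    if inputs and all(inp.op == "get_attr" for inp in inputs):
        node.meta[IMPLICIT_Q_DQ_TAG] = True

Rules (b) and (c) additionally require walking the producer/consumer nodes around the supported group, which is omitted from this sketch.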

File tree: 10 files changed, +423 -119 lines changed

backends/xnnpack/operators/op_dequantize_per_tensor.py

Lines changed: 4 additions & 9 deletions
@@ -11,6 +11,7 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConvert,
     XNNGraph,
@@ -23,12 +24,7 @@
 @register_node_visitor
 class OpDeQuantizePerTensor(NodeVisitor):
     """
-    Dequantize Per Tensor Node visitor. We only insert an XNNPACK node if
-    this op was found as a graph input or graph output. This is so we
-    dequantize the input going in. Every other instance of quantize per
-    tensor is only used as signaling for q params of node inputs, so
-    we ignore those. This is because xnnpack only supports entire graph
-    quantization
+    Dequantize Per Tensor Node visitor
     """

     target = "quantized_decomposed.dequantize_per_tensor.default"
@@ -44,10 +40,9 @@ def define_node(
         debug_handle: int,
     ) -> None:
         """
-        We only define a node if it is a graph output
+        We only define a node if it is not an implict dq node
         """
-        # TODO:@maxren better handle in-graph quantization conversions, this is hacky
-        if self.is_graph_output(node):
+        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
             dq_input = get_input_node(node, 0)
             input_quant_params = QuantParams.from_q_dq_node(node)
             # fp32 output

backends/xnnpack/operators/op_quantize_per_tensor.py

Lines changed: 4 additions & 9 deletions
@@ -11,6 +11,7 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConvert,
     XNNGraph,
@@ -23,12 +24,7 @@
 @register_node_visitor
 class OpQuantizePerTensor(NodeVisitor):
     """
-    Quantize Per Tensor Node visitor. We only insert an XNNPACK node if
-    this op was found as a graph input or graph output. This is so we
-    quantize the input going in. Every other instance of quantize per
-    tensor is only used as signaling for q params of node inputs, so
-    we ignore those. This is because xnnpack only supports entire graph
-    quantization
+    Quantize Per Tensor Node visitor
     """

     target = "quantized_decomposed.quantize_per_tensor.default"
@@ -44,11 +40,10 @@ def define_node(
         debug_handle: int,
     ) -> None:
         """
-        We only define a node if it is a graph input
+        We only define a node if it is not an implict q node
         """
-        # TODO:@maxren better handle in-graph quantization conversions, this is hacky
         q_input = get_input_node(node, 0)
-        if self.is_graph_input(q_input):
+        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
             input_quant_params = QuantParams.from_q_dq_node(node)
             # fp32 input
             self.define_tensor(q_input, xnn_graph, vals_to_ids)

backends/xnnpack/partition/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -25,6 +25,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        ":configs",
         ":support_patterns",
         "//executorch/backends/xnnpack:xnnpack_preprocess",
         "//executorch/exir:delegate",
@@ -34,3 +35,17 @@ runtime.python_library(
         "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
     ],
 )
+
+runtime.python_library(
+    name = "configs",
+    srcs = [
+        "configs.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/exir:lib",
+    ],
+)

backends/xnnpack/partition/configs.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+
+###
+### Module based partitioners
+###
+
+SUPPORTED_OPS = [
+    exir_ops.edge.aten.div.Tensor,
+    exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.clamp.default,
+    exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.floor.default,
+    exir_ops.edge.aten.maximum.default,
+    exir_ops.edge.aten.minimum.default,
+    exir_ops.edge.aten.mul.Tensor,
+    exir_ops.edge.aten.constant_pad_nd.default,
+    exir_ops.edge.aten.upsample_bilinear2d.default,
+    exir_ops.edge.aten.mean.dim,
+    exir_ops.edge.aten.max.dim,
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.sqrt.default,
+    exir_ops.edge.aten.ceil.default,
+    exir_ops.edge.aten.hardswish.default,
+    exir_ops.edge.aten.neg.default,
+    exir_ops.edge.aten.pow.Tensor_Scalar,
+    exir_ops.edge.aten.abs.default,
+    exir_ops.edge.aten._prelu_kernel.default,
+    exir_ops.edge.aten.slice_copy.Tensor,
+]
+
+SUPPORTED_MODULES = [
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.ReLU,
+    torch.nn.Sigmoid,
+    torch.nn.Softmax,
+    torch.nn.BatchNorm1d,
+    torch.nn.BatchNorm2d,
+    torch.nn.Linear,
+    torch.nn.functional.linear,
+    torch.nn.Hardtanh,
+    torch.nn.MaxPool2d,
+    torch.nn.LeakyReLU,
+    torch.nn.ELU,
+    torch.nn.AvgPool2d,
+    torch.nn.PReLU,  # Without this, the PReLU weight becomes not a get_attr
+    torch.cat,
+    torch.concat,
+    torch.concatenate,
+]
+
+# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support
+SUPPORTED_QUANT_OPS = [
+    exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.mul.Tensor,
+    exir_ops.edge.aten.mean.dim,
+    exir_ops.edge.aten.hardtanh.default,  # TODO - which one module or op or both?
+    exir_ops.edge.aten.slice_copy.Tensor,
+]
+
+SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
+    op.name()
+    for op in (
+        SUPPORTED_QUANT_OPS
+        + [
+            exir_ops.edge.aten._to_copy.default,
+            exir_ops.edge.aten.max_pool2d.default,
+            exir_ops.edge.aten.linear.default,
+        ]
+    )
+}
+
+# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
+SUPPORTED_QUANT_MODULES = [
+    torch.clamp,
+    torch.mean,
+    torch.permute,
+    torch.permute_copy,
+    torch.cat,
+    torch.concat,
+    torch.concatenate,
+    torch.nn.Linear,
+    torch.nn.functional.linear,
+    # TODO - T158982884
+    # torch.ao.nn.quantized.reference.modules.linear.Linear,
+    torch.nn.MaxPool2d,
+    torch.nn.Conv1d,
+    torch.nn.functional.conv1d,
+    torch.ao.nn.quantized.reference.modules.conv.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.functional.conv2d,
+    torch.nn.functional.pad,
+    torch.nn.functional.elu,
+    torch.ao.nn.quantized.reference.modules.conv.Conv2d,
+    torch.nn.BatchNorm1d,
+    torch.nn.BatchNorm2d,
+    torch.nn.ConstantPad2d,
+    torch.nn.ELU,
+    torch.nn.Hardtanh,
+    torch.nn.ReLU,
+    torch.nn.functional.relu,
+    torch.nn.functional.relu_,
+    torch.nn.functional.leaky_relu,
+    torch.nn.functional.leaky_relu_,
+    torch.nn.LeakyReLU,
+]
+
+SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)
+
+# Modules which support dynamic quantization
+SUPPORTED_DYN_QUANT_MODULES = [
+    torch.nn.Linear,
+    torch.nn.functional.linear,
+]
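Since SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET stores op name strings (built via op.name() above), a membership check for a graph node might look like the sketch below. The helper is hypothetical and assumes the node's target exposes name() the same way the edge ops used to build the set do.

import torch.fx

from executorch.backends.xnnpack.partition.configs import (
    SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET,
)


def is_implicit_q_dq_candidate_op(node: torch.fx.Node) -> bool:
    # Hypothetical helper, not part of this commit. Edge op targets expose
    # name() (the same call used above to build the set); hasattr guards plain
    # Python callables that do not.
    return (
        node.op == "call_function"
        and hasattr(node.target, "name")
        and node.target.name() in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
    )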

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 8 additions & 101 deletions
@@ -10,6 +10,14 @@
 from typing import Any, Callable, cast, Dict, List, Optional, Union

 import torch
+
+from executorch.backends.xnnpack.partition.configs import (
+    SUPPORTED_DYN_QUANT_MODULES,
+    SUPPORTED_MODULES,
+    SUPPORTED_OPS,
+    SUPPORTED_QUANT_MODULES,
+    SUPPORTED_QUANT_OPS,
+)
 from executorch.backends.xnnpack.partition.support_patterns import (
     get_add_graphs,
     get_all_dynamically_quantized_linear_pattern,
@@ -522,107 +530,6 @@ def __init__(self):
         )


-###
-### Module based partitioners
-###
-
-SUPPORTED_OPS = [
-    exir_ops.edge.aten.div.Tensor,
-    exir_ops.edge.aten.add.Tensor,
-    exir_ops.edge.aten.clamp.default,
-    exir_ops.edge.aten.sub.Tensor,
-    exir_ops.edge.aten.floor.default,
-    exir_ops.edge.aten.maximum.default,
-    exir_ops.edge.aten.minimum.default,
-    exir_ops.edge.aten.mul.Tensor,
-    exir_ops.edge.aten.constant_pad_nd.default,
-    exir_ops.edge.aten.upsample_bilinear2d.default,
-    exir_ops.edge.aten.mean.dim,
-    exir_ops.edge.aten.max.dim,
-    exir_ops.edge.aten.hardtanh.default,
-    exir_ops.edge.aten.sqrt.default,
-    exir_ops.edge.aten.ceil.default,
-    exir_ops.edge.aten.hardswish.default,
-    exir_ops.edge.aten.neg.default,
-    exir_ops.edge.aten.pow.Tensor_Scalar,
-    exir_ops.edge.aten.abs.default,
-    exir_ops.edge.aten._prelu_kernel.default,
-    exir_ops.edge.aten.slice_copy.Tensor,
-]
-
-SUPPORTED_MODULES = [
-    torch.nn.Conv1d,
-    torch.nn.Conv2d,
-    torch.nn.ReLU,
-    torch.nn.Sigmoid,
-    torch.nn.Softmax,
-    torch.nn.BatchNorm1d,
-    torch.nn.BatchNorm2d,
-    torch.nn.Linear,
-    torch.nn.functional.linear,
-    torch.nn.Hardtanh,
-    torch.nn.MaxPool2d,
-    torch.nn.LeakyReLU,
-    torch.nn.ELU,
-    torch.nn.AvgPool2d,
-    torch.nn.PReLU,  # Without this, the PReLU weight becomes not a get_attr
-    torch.cat,
-    torch.concat,
-    torch.concatenate,
-]
-
-# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support
-SUPPORTED_QUANT_OPS = [
-    exir_ops.edge.aten.add.Tensor,
-    exir_ops.edge.aten.sub.Tensor,
-    exir_ops.edge.aten.mul.Tensor,
-    exir_ops.edge.aten.mean.dim,
-    exir_ops.edge.aten.hardtanh.default,  # TODO - which one module or op or both?
-    exir_ops.edge.aten.slice_copy.Tensor,
-]
-
-# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
-SUPPORTED_QUANT_MODULES = [
-    torch.clamp,
-    torch.mean,
-    torch.permute,
-    torch.permute_copy,
-    torch.cat,
-    torch.concat,
-    torch.concatenate,
-    torch.nn.Linear,
-    torch.nn.functional.linear,
-    # TODO - T158982884
-    # torch.ao.nn.quantized.reference.modules.linear.Linear,
-    torch.nn.MaxPool2d,
-    torch.nn.Conv1d,
-    torch.nn.functional.conv1d,
-    torch.ao.nn.quantized.reference.modules.conv.Conv1d,
-    torch.nn.Conv2d,
-    torch.nn.functional.conv2d,
-    torch.nn.functional.pad,
-    torch.nn.functional.elu,
-    torch.ao.nn.quantized.reference.modules.conv.Conv2d,
-    torch.nn.BatchNorm1d,
-    torch.nn.BatchNorm2d,
-    torch.nn.ConstantPad2d,
-    torch.nn.ELU,
-    torch.nn.Hardtanh,
-    torch.nn.ReLU,
-    torch.nn.functional.relu,
-    torch.nn.functional.relu_,
-    torch.nn.functional.leaky_relu,
-    torch.nn.functional.leaky_relu_,
-    torch.nn.LeakyReLU,
-]
-
-# Modules which support dynamic quantization
-SUPPORTED_DYN_QUANT_MODULES = [
-    torch.nn.Linear,
-    torch.nn.functional.linear,
-]
-
-
 class XnnpackFloatingPointPartitioner(Partitioner):
     """
     Module and Opname based partitioner for FP32 modules/ops listed in

backends/xnnpack/passes/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -10,10 +10,12 @@ python_library(
         "fuse_batch_norm_with_conv.py",
         "prelu_reshape_pass.py",
         "remove_getitem_op.py",
+        "tag_implicit_q_dq_pass.py",
    ],
    deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:lib",
+        "//executorch/backends/xnnpack/partition:configs",
         "//executorch/backends/xnnpack/utils:xnnpack_utils",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",

backends/xnnpack/passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 )
 from executorch.backends.xnnpack.passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
+from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass

 from executorch.exir.passes import PassManager
 from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -27,5 +28,6 @@
         Conv1dUnsqueezePass(),
         PReLUReshapePass(),
         ChannelsLastTaggedReshapePass(),
+        TagImplicitQDqPass(),
     ]
 )
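For context, here is a rough sketch of applying a pass list like the one above to a torch.fx.GraphModule. It assumes each pass follows torch.fx's PassBase convention of being callable on a graph module and returning a PassResult, and the loop is illustrative rather than the exact invocation ExecuTorch uses.

from typing import Callable, List

import torch.fx


def run_passes(
    graph_module: torch.fx.GraphModule, passes: List[Callable]
) -> torch.fx.GraphModule:
    # Illustrative only: apply torch.fx-style passes in order, threading the
    # (possibly rewritten) graph module through each PassResult.
    for p in passes:
        result = p(graph_module)
        if result is not None and getattr(result, "graph_module", None) is not None:
            graph_module = result.graph_module
    graph_module.recompile()
    return graph_module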
