Skip to content

Commit e24d503

Browse files
perfreddan80
authored and committed
Convert more NodeVisitors to folding DQ/Q pass usage
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I9201d8bafd543204b697c7276d6929ad3aa09f25
1 parent eae61f7 commit e24d503

File tree

8 files changed

+143
-90
lines changed

8 files changed

+143
-90
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def transform_to_backend_pipeline(
9090
exir_ops.edge.aten.minimum.default,
9191
exir_ops.edge.aten.maximum.default,
9292
exir_ops.edge.aten.add.Tensor,
93+
exir_ops.edge.aten.avg_pool2d.default,
94+
exir_ops.edge.aten.convolution.default,
9395
]
9496
)
9597
)

backends/arm/operators/op_avg_pool2d.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,118 @@
88

99
import serializer.tosa_serializer as ts
1010
import torch
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_input_qparams,
13+
get_output_qparams,
14+
)
1115
from executorch.backends.arm.operators.node_visitor import (
1216
NodeVisitor,
1317
register_node_visitor,
1418
)
1519
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
1721

1822

1923
@register_node_visitor
20-
class AvgPool2dVisitor(NodeVisitor):
24+
class AvgPool2dVisitor_0_80_BI(NodeVisitor):
2125
target = "aten.avg_pool2d.default"
2226

27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29+
]
30+
2331
def __init__(self, *args):
2432
super().__init__(*args)
2533

26-
def define_node(
34+
def _build_generic_avgpool2d(
2735
self,
2836
node: torch.fx.Node,
2937
tosa_graph: ts.TosaSerializer,
3038
inputs: List[TosaArg],
3139
output: TosaArg,
32-
is_quant_node: bool,
40+
input_zp: int,
41+
output_zp: int,
42+
accumulator_type,
3343
) -> None:
3444
input_tensor = inputs[0]
45+
3546
kernel_size_list = inputs[1].special
3647
stride_size_list = inputs[2].special
3748
try:
3849
pad_size_list = inputs[3].special
3950
except IndexError:
4051
pad_size_list = [0, 0, 0, 0]
4152

42-
build_avg_pool_2d_common(
43-
node,
44-
tosa_graph,
45-
input_tensor,
46-
kernel_size_list,
47-
stride_size_list,
48-
pad_size_list,
49-
is_quant_node,
50-
output,
53+
attr = ts.TosaSerializerAttribute()
54+
attr.PoolAttribute(
55+
kernel=kernel_size_list,
56+
stride=stride_size_list,
57+
pad=pad_size_list,
58+
input_zp=input_zp,
59+
output_zp=output_zp,
60+
accum_dtype=accumulator_type,
61+
)
62+
63+
tosa_graph.addOperator(
64+
ts.TosaOp.Op().AVG_POOL2D,
65+
[input_tensor.name],
66+
[output.name],
67+
attr,
68+
)
69+
70+
def define_node(
71+
self,
72+
node: torch.fx.Node,
73+
tosa_graph: ts.TosaSerializer,
74+
inputs: List[TosaArg],
75+
output: TosaArg,
76+
is_quant_node: bool,
77+
) -> None:
78+
input_tensor = inputs[0]
79+
assert input_tensor.dtype == ts.DType.INT8
80+
81+
accumulator_type = ts.DType.INT32
82+
83+
input_qargs = get_input_qparams(node)
84+
input_zp = input_qargs[0].zp
85+
86+
output_qargs = get_output_qparams(node)
87+
output_zp = output_qargs[0].zp
88+
89+
self._build_generic_avgpool2d(
90+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
5191
)
92+
93+
94+
@register_node_visitor
95+
class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
96+
# inheriting 'target' from BI class
97+
98+
tosa_specs = [
99+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
100+
]
101+
102+
def define_node(
103+
self,
104+
node: torch.fx.Node,
105+
tosa_graph: ts.TosaSerializer,
106+
inputs: List[TosaArg],
107+
output: TosaArg,
108+
is_quant_node: bool,
109+
) -> None:
110+
assert (
111+
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
112+
), "Only FP32 and INT8 supported"
113+
114+
if inputs[0].dtype == ts.DType.INT8:
115+
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
116+
117+
if inputs[0].dtype == ts.DType.FP32:
118+
accumulator_type = ts.DType.FP32
119+
# Initialize zero point to zero.
120+
input_zp = 0
121+
output_zp = 0
122+
123+
self._build_generic_avgpool2d(
124+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
125+
)

backends/arm/operators/op_batch_norm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
register_node_visitor,
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
1617
from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
1718
from serializer.tosa_serializer import TosaOp
1819

@@ -21,6 +22,10 @@
2122
class BatchNormVisitor(NodeVisitor):
2223
target = "aten._native_batch_norm_legit_no_training.default"
2324

25+
tosa_specs = [
26+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
27+
]
28+
2429
def __init__(self, *args):
2530
super().__init__(*args)
2631

backends/arm/operators/op_conv2d.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
import serializer.tosa_serializer as ts
1010
import torch
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_input_qparams,
13+
get_output_qparams,
14+
)
1115
from executorch.backends.arm.operators.node_visitor import (
1216
NodeVisitor,
1317
register_node_visitor,
1418
)
1519
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_quant_utils import (
17-
build_rescale_conv_output,
18-
get_quant_arg_downstream,
19-
get_quant_arg_upstream,
20-
)
20+
from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
2121
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
2222

2323
from serializer.tosa_serializer import TosaOp
@@ -57,9 +57,6 @@ def define_node(
5757
) -> None:
5858
input, weight, bias, stride, pad, dilation, _, _, group = inputs
5959

60-
# Currently only int8 is supported in quantized types.
61-
actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype
62-
6360
# Get the attributes of convolution.
6461
attr = ts.TosaSerializerAttribute()
6562
pad_attr = [val for val in pad.special for _ in (0, 1)]
@@ -82,9 +79,11 @@ def define_node(
8279
dilation_attr[1],
8380
)
8481

85-
input_zp = (
86-
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
87-
)
82+
input_zp = 0
83+
if inputs[0].dtype == ts.DType.INT8:
84+
# int8 input requires quantization information
85+
input_qparams = get_input_qparams(node)
86+
input_zp = input_qparams[0].zp
8887

8988
attr.ConvAttribute(
9089
pad=pad_attr,
@@ -100,16 +99,22 @@ def define_node(
10099
# Create a zero bias tensor if not presented
101100
out_channels = weight.shape[0]
102101
bias_name = "bias" + node.name.split("default", 1)[1]
102+
bias_type = output.dtype
103+
if output.dtype == ts.DType.INT8:
104+
# Conv is quantized to int8, but the TOSA operator has
105+
# output type int32, and the bias must be the same type
106+
# as the TOSA output type
107+
bias_type = ts.DType.INT32
103108
bias = tosa_graph.addConst(
104109
[out_channels],
105-
ts.DType.INT32 if is_quant_node else output.dtype,
110+
bias_type,
106111
[0] * out_channels,
107112
name=bias_name,
108113
)
109114

110115
# The output type is int32 when input type is int8.
111116
conv2d_output_name = output.name
112-
if is_quant_node:
117+
if output.dtype == ts.DType.INT8:
113118
conv2d_res = tosa_graph.addIntermediate(
114119
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
115120
)
@@ -132,7 +137,7 @@ def define_node(
132137

133138
weight_reshaped = tosa_graph.addIntermediate(
134139
weight_post_shape,
135-
ts.DType.INT8 if is_quant_node else weight.dtype,
140+
weight.dtype,
136141
)
137142
build_reshape(
138143
tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
@@ -157,20 +162,19 @@ def define_node(
157162

158163
# For quantized convolution, rescale the output value back to the same
159164
# integer value domain of the next op. Otherwise return float32 output.
160-
if is_quant_node:
165+
if inputs[0].dtype == ts.DType.INT8:
161166
# Get scale_factor from input, weight, and output.
162-
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
163-
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
164-
output_qargs = get_quant_arg_downstream(list(node.users)[0])
165-
167+
input_scale = input_qparams[0].scale
168+
weight_scale = input_qparams[1].scale
169+
output_qargs = get_output_qparams(node)
166170
build_rescale_conv_output(
167171
tosa_graph,
168172
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
169173
conv2d_res,
170174
output.name,
171-
actual_out_type,
175+
output.dtype,
172176
input_scale,
173177
weight_scale,
174-
output_qargs.scale,
175-
output_qargs.zp,
178+
output_qargs[0].scale,
179+
output_qargs[0].zp,
176180
)

backends/arm/operators/op_div.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
register_node_visitor,
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
1617
from executorch.backends.arm.tosa_utils import tosa_shape
1718
from serializer.tosa_serializer import TosaOp
1819

@@ -21,6 +22,11 @@
2122
class DivVisitor(NodeVisitor):
2223
target = "aten.div.Tensor"
2324

25+
# Only supported for MI
26+
tosa_specs = [
27+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
28+
]
29+
2430
def __init__(self, *args):
2531
super().__init__(*args)
2632

backends/arm/operators/op_max_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
register_node_visitor,
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_utils import (
16+
from executorch.backends.arm.tosa_quant_utils import (
1717
get_quant_arg_downstream,
1818
get_quant_arg_upstream,
1919
)

backends/arm/process_node.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
import serializer.tosa_serializer as ts
1212
import torch
1313
import torch.fx
14+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15+
get_input_qparams,
16+
)
1417
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1518
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
1619
from executorch.backends.arm.tosa_quant_utils import (
17-
get_quant_arg_upstream,
1820
get_quantized_node_output_dtype,
1921
is_node_quantized,
2022
)
@@ -110,8 +112,10 @@ def process_quantized_bias(
110112
_,
111113
) = consumer_node.all_input_nodes
112114

113-
input_node_scale = get_quant_arg_upstream(input_node).scale
114-
weight_node_scale = get_quant_arg_upstream(weight_node).scale
115+
input_qargs = get_input_qparams(consumer_node)
116+
117+
input_node_scale = input_qargs[0].scale
118+
weight_node_scale = input_qargs[1].scale
115119
bias_values_quantized = (
116120
(parameter_values / (input_node_scale * weight_node_scale))
117121
.round()

0 commit comments

Comments
 (0)