Skip to content

Commit 3378258

Browse files
committed
Revert "Search graph for quantization nodes (#6452)"
This reverts commit 63017e4.
1 parent 3a1f8d2 commit 3378258

24 files changed

+175
-292
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
get_first_fake_tensor,
1515
insert_q_dq_pair,
1616
)
17-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
17+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1818
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass, PassResult
@@ -42,9 +42,6 @@ def _transpose_impl(*args, **kwargs):
4242
return args[0]
4343

4444

45-
register_passable_op(torch.ops.passthrough_to_tosa._transpose)
46-
47-
4845
class AnnotateChannelsLastDimOrder(ExportPass):
4946
"""
5047
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order

backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,9 @@
88

99
import torch
1010
import torch.fx
11-
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
12+
13+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
1214
from executorch.exir.dialects._ops import ops as exir_ops
1315
from executorch.exir.pass_base import ExportPass, PassResult
1416

@@ -26,6 +28,8 @@ class InsertSqueezeAfterSumPass(ExportPass):
2628
sum(dims, keep_dim = False)
2729
After pass:
2830
sum(dims, keep_dim = True)
31+
(q)
32+
(dq)
2933
squeeze(dim = dims)
3034
"""
3135

@@ -41,6 +45,12 @@ def call(self, graph_module: torch.fx.GraphModule):
4145
continue
4246

4347
dim_list = cast(list[int], sum_node.args[1])
48+
quantized = is_quant_node(sum_node)
49+
if quantized:
50+
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51+
qparams = qparams + (torch.int8,)
52+
else:
53+
qparams = None
4454

4555
# Add keep_dim = True arg to sum node.
4656
sum_node.args = sum_node.args[0:2] + (True,)
@@ -51,6 +61,8 @@ def call(self, graph_module: torch.fx.GraphModule):
5161
)
5262
sum_node.replace_all_uses_with(squeeze_node)
5363
squeeze_node.args = (sum_node, dim_list)
64+
if quantized:
65+
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
5466
graph_module.graph.eliminate_dead_code()
5567
graph_module.recompile()
5668
graph_module = super().call(graph_module).graph_module

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from typing import cast, Optional
1010

1111
import torch.fx
12-
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
12+
from executorch.backends.arm.tosa_quant_utils import is_quant_node
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch._ops import OpOverload
@@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
113113
slice_node = graph.create_node(
114114
"call_function", self.slice_op, (last_node,) + args
115115
)
116-
if is_node_quantized(last_node):
116+
if is_quant_node(last_node):
117117
q_params = last_node.args[1:]
118118
dq_node = insert_q_dq_pair(
119119
graph_module.graph, slice_node, q_params

backends/arm/operators/op_addmm.py

Lines changed: 24 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,10 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import (
18-
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
21-
)
17+
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
2218

2319
from executorch.backends.arm.tosa_utils import build_reshape
20+
from executorch.exir.dialects._ops import ops as exir_ops
2421
from serializer.tosa_serializer import TosaOp
2522

2623

@@ -70,7 +67,12 @@ def define_node(
7067
input_zp = 0
7168
if is_quant_node:
7269
input_node = node.all_input_nodes[1]
73-
input_zp = search_quant_arg_upstream(input_node).zp
70+
# rank > 2 linear layer
71+
if input_node.target == exir_ops.edge.aten.view_copy.default:
72+
quant_node = input_node.all_input_nodes[0]
73+
else:
74+
quant_node = input_node
75+
input_zp = get_quant_node_args(quant_node).zp
7476
attr.ConvAttribute(
7577
pad=pad_attr,
7678
stride=stride_attr,
@@ -105,16 +107,24 @@ def define_node(
105107
# Read inputs' parent nodes
106108
_, input_node, weight_node = node.all_input_nodes
107109

108-
qargs = search_quant_arg_upstream(input_node)
109-
input_scale = qargs.scale
110-
consumer_node = list(node.users)[0]
111-
quant_args = search_quant_arg_downstream(consumer_node)
112-
113-
consumer_node_scale = quant_args.scale
114-
consumer_node_node_zp = quant_args.zp
110+
# rank > 2 linear layer
111+
if input_node.target == exir_ops.edge.aten.view_copy.default:
112+
quant_node = input_node.all_input_nodes[0]
113+
input_scale = get_quant_node_args(quant_node).scale
114+
consumer_node = list(node.users)[0]
115+
consumer_consumer_node = list(consumer_node.users)[0]
116+
quant_args = get_quant_node_args(consumer_consumer_node)
117+
consumer_node_scale = quant_args.scale
118+
consumer_node_node_zp = quant_args.zp
119+
else:
120+
input_scale = get_quant_node_args(input_node).scale
121+
consumer_node = list(node.users)[0]
122+
quant_args = get_quant_node_args(consumer_node)
123+
consumer_node_scale = quant_args.scale
124+
consumer_node_node_zp = quant_args.zp
115125

116126
weight_node_q_node = weight_node.all_input_nodes[0]
117-
weight_scale = search_quant_arg_upstream(weight_node_q_node).scale
127+
weight_scale = get_quant_node_args(weight_node_q_node).scale
118128

119129
output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
120130

backends/arm/operators/op_bmm.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,7 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import (
18-
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
21-
)
17+
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
2218
from executorch.backends.arm.tosa_utils import get_two_inputs
2319
from serializer.tosa_serializer import TosaOp
2420

@@ -46,10 +42,8 @@ def define_node(
4642
# For INT8, we need to get the zero points and add an intermediate tensor
4743
# for a later rescale.
4844
if is_quant_node:
49-
input0_q_params = search_quant_arg_upstream(input0)
50-
input1_q_params = search_quant_arg_upstream(input1)
51-
input0_zp = input0_q_params.zp
52-
input1_zp = input1_q_params.zp
45+
input0_zp = get_quant_node_args(input0).zp
46+
input1_zp = get_quant_node_args(input1).zp
5347
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
5448
bmm_output_name = bmm_result.name
5549
else:
@@ -69,7 +63,9 @@ def define_node(
6963

7064
# As INT8 accumulates into INT32, we need to rescale it back to INT8
7165
if is_quant_node:
72-
output_q_params = search_quant_arg_downstream(list(node.users)[0])
66+
input0_q_params = get_quant_node_args(input0)
67+
input1_q_params = get_quant_node_args(input1)
68+
output_q_params = get_quant_node_args(list(node.users)[0])
7369

7470
final_output_scale = (
7571
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_conv2d.py

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import cast, List
88

99
import serializer.tosa_serializer as ts
1010
import torch
@@ -15,10 +15,9 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_quant_utils import (
1717
build_rescale_conv_output,
18-
search_quant_arg_downstream,
19-
search_quant_arg_upstream,
18+
get_quant_node_args,
2019
)
21-
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
20+
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
2221

2322
from serializer.tosa_serializer import TosaOp
2423

@@ -83,9 +82,7 @@ def define_node(
8382
)
8483

8584
input_zp = (
86-
search_quant_arg_upstream(node.all_input_nodes[0]).zp
87-
if is_quant_node
88-
else 0
85+
get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
8986
)
9087

9188
attr.ConvAttribute(
@@ -161,10 +158,9 @@ def define_node(
161158
# integer value domain of the next op. Otherwise return float32 output.
162159
if is_quant_node:
163160
# Get scale_factor from input, weight, and output.
164-
input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale
165-
weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale
166-
output_qargs = search_quant_arg_downstream(list(node.users)[0])
167-
161+
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
162+
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
163+
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
168164
build_rescale_conv_output(
169165
tosa_graph,
170166
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
@@ -173,6 +169,6 @@ def define_node(
173169
actual_out_type,
174170
input_scale,
175171
weight_scale,
176-
output_qargs.scale,
177-
output_qargs.zp,
172+
output_scale,
173+
output_zp,
178174
)

backends/arm/operators/op_exp.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,9 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_node_args,
2021
QuantArgs,
2122
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2423
)
2524
from serializer.tosa_serializer import TosaOp
2625
from torch.fx import Node
@@ -49,9 +48,9 @@ def define_node(
4948

5049
# Create attribute for 8 bit table lookup.
5150
input_node = node.all_input_nodes[0]
52-
in_quantargs = search_quant_arg_upstream(input_node)
51+
in_quantargs = get_quant_node_args(input_node)
5352
output_node = list(node.users)[0]
54-
out_quantargs = search_quant_arg_downstream(output_node)
53+
out_quantargs = get_quant_node_args(output_node)
5554

5655
table = exp_table_8bit(in_quantargs, out_quantargs)
5756
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_full.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,7 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import (
18-
quantize_value,
19-
search_quant_arg_downstream,
20-
)
17+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
2118
from executorch.backends.arm.tosa_utils import tosa_shape
2219
from torch.fx import Node
2320

@@ -42,8 +39,10 @@ def define_node(
4239

4340
value = inputs[1].number
4441
if is_quant_node:
45-
qargs = search_quant_arg_downstream(list(node.users)[0])
46-
qvalue = quantize_value(value, qargs)
42+
qargs = get_quant_node_args(list(node.users)[0])
43+
qvalue = np.clip(
44+
np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
45+
)
4746
dtype = ts.DType.INT8
4847
data = np.full(shape, qvalue, dtype=np.int8)
4948
else:

backends/arm/operators/op_hardtanh.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,7 @@
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616

17-
from executorch.backends.arm.tosa_quant_utils import (
18-
quantize_value,
19-
search_quant_arg_upstream,
20-
)
17+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
2118
from serializer.tosa_serializer import TosaOp
2219

2320

@@ -40,10 +37,12 @@ def define_node(
4037

4138
if is_quant_node:
4239
# Get quant parameters
43-
qargs = search_quant_arg_upstream(node.all_input_nodes[0])
40+
scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
4441
# Convert to quantized representation
45-
clamp_min_qs = quantize_value(inputs[1].number, qargs)
46-
clamp_max_qs = quantize_value(inputs[2].number, qargs)
42+
clamp_min_qs = round((inputs[1].number / scale) + zp)
43+
clamp_min_qs = max(clamp_min_qs, qmin)
44+
clamp_max_qs = round((inputs[2].number / scale) + zp)
45+
clamp_max_qs = min(clamp_max_qs, qmax)
4746
# Set fp values to 0.0 since they are not used
4847
clamp_min_fp = 0.0
4948
clamp_max_fp = 0.0

backends/arm/operators/op_log.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,9 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_node_args,
2021
QuantArgs,
2122
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2423
)
2524
from serializer.tosa_serializer import TosaOp
2625
from torch.fx import Node
@@ -50,9 +49,9 @@ def define_node(
5049

5150
# Create attribute for 8 bit table lookup.
5251
input_node = node.all_input_nodes[0]
53-
in_quantargs = search_quant_arg_upstream(input_node)
52+
in_quantargs = get_quant_node_args(input_node)
5453
output_node = list(node.users)[0]
55-
out_quantargs = search_quant_arg_downstream(output_node)
54+
out_quantargs = get_quant_node_args(output_node)
5655

5756
table = log_table_8bit(in_quantargs, out_quantargs)
5857
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_mm.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,7 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import (
18-
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
21-
)
17+
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
2218
from executorch.backends.arm.tosa_utils import (
2319
build_reshape,
2420
expand_dims,
@@ -58,8 +54,8 @@ def define_node(
5854
# For INT8, we need to get the zero point, otherwise it is 0
5955
input0_zp, input1_zp = 0, 0
6056
if is_quant_node:
61-
input0_zp = search_quant_arg_upstream(input0).zp
62-
input1_zp = search_quant_arg_upstream(input1).zp
57+
input0_zp = get_quant_node_args(input0).zp
58+
input1_zp = get_quant_node_args(input1).zp
6359

6460
mat_mul_result = tosa_graph.addIntermediate(
6561
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
@@ -90,9 +86,9 @@ def define_node(
9086

9187
# As INT8 accumulates into INT32, we need to rescale it back to INT8
9288
if is_quant_node:
93-
input0_q_params = search_quant_arg_upstream(input0)
94-
input1_q_params = search_quant_arg_upstream(input1)
95-
output_q_params = search_quant_arg_downstream(list(node.users)[0])
89+
input0_q_params = get_quant_node_args(input0)
90+
input1_q_params = get_quant_node_args(input1)
91+
output_q_params = get_quant_node_args(list(node.users)[0])
9692

9793
final_output_scale = (
9894
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,10 +37,10 @@ def define_node(
3737
if is_quant_node:
3838
input_A = inputs[0]
3939
input_B = inputs[1]
40-
input_A_qargs = tqutils.search_quant_arg_upstream(
40+
input_A_qargs = tqutils.get_quant_node_args(
4141
cast(torch.fx.Node, node.args[0])
4242
)
43-
input_B_qargs = tqutils.search_quant_arg_upstream(
43+
input_B_qargs = tqutils.get_quant_node_args(
4444
cast(torch.fx.Node, node.args[1])
4545
)
4646

0 commit comments

Comments
 (0)