Skip to content

Commit feb2354

Browse files
committed
Fix type-checking issues in Arm backend
Signed-off-by: Erik Lundell <[email protected]> Change-Id: I4feb4b5c6d269d7c0ff4312c17fec52da413fa5a
1 parent 41b60aa commit feb2354

20 files changed

+111
-82
lines changed

backends/arm/operators/op_addmm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from executorch.backends.arm.tosa_mapping import TosaArg
1717
from executorch.backends.arm.tosa_quant_utils import (
1818
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
2121
)
2222

2323
from executorch.backends.arm.tosa_utils import build_reshape
@@ -70,7 +70,7 @@ def define_node(
7070
input_zp = 0
7171
if is_quant_node:
7272
input_node = node.all_input_nodes[1]
73-
input_zp = search_quant_arg_upstream(input_node).zp
73+
input_zp = get_quant_arg_upstream(input_node).zp
7474
attr.ConvAttribute(
7575
pad=pad_attr,
7676
stride=stride_attr,
@@ -105,16 +105,16 @@ def define_node(
105105
# Read inputs' parent nodes
106106
_, input_node, weight_node = node.all_input_nodes
107107

108-
qargs = search_quant_arg_upstream(input_node)
108+
qargs = get_quant_arg_upstream(input_node)
109109
input_scale = qargs.scale
110110
consumer_node = list(node.users)[0]
111-
quant_args = search_quant_arg_downstream(consumer_node)
111+
quant_args = get_quant_arg_downstream(consumer_node)
112112

113113
consumer_node_scale = quant_args.scale
114114
consumer_node_node_zp = quant_args.zp
115115

116116
weight_node_q_node = weight_node.all_input_nodes[0]
117-
weight_scale = search_quant_arg_upstream(weight_node_q_node).scale
117+
weight_scale = get_quant_arg_upstream(weight_node_q_node).scale
118118

119119
output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
120120

backends/arm/operators/op_bmm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from executorch.backends.arm.tosa_mapping import TosaArg
1717
from executorch.backends.arm.tosa_quant_utils import (
1818
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
2121
)
2222
from executorch.backends.arm.tosa_utils import get_two_inputs
2323
from serializer.tosa_serializer import TosaOp
@@ -46,8 +46,8 @@ def define_node(
4646
# For INT8, we need to get the zero points and add an intermediate tensor
4747
# for a later rescale.
4848
if is_quant_node:
49-
input0_q_params = search_quant_arg_upstream(input0)
50-
input1_q_params = search_quant_arg_upstream(input1)
49+
input0_q_params = get_quant_arg_upstream(input0)
50+
input1_q_params = get_quant_arg_upstream(input1)
5151
input0_zp = input0_q_params.zp
5252
input1_zp = input1_q_params.zp
5353
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
@@ -69,7 +69,7 @@ def define_node(
6969

7070
# As INT8 accumulates into INT32, we need to rescale it back to INT8
7171
if is_quant_node:
72-
output_q_params = search_quant_arg_downstream(list(node.users)[0])
72+
output_q_params = get_quant_arg_downstream(list(node.users)[0])
7373

7474
final_output_scale = (
7575
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_conv2d.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_quant_utils import (
1717
build_rescale_conv_output,
18-
search_quant_arg_downstream,
19-
search_quant_arg_upstream,
18+
get_quant_arg_downstream,
19+
get_quant_arg_upstream,
2020
)
2121
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
2222

@@ -83,9 +83,7 @@ def define_node(
8383
)
8484

8585
input_zp = (
86-
search_quant_arg_upstream(node.all_input_nodes[0]).zp
87-
if is_quant_node
88-
else 0
86+
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
8987
)
9088

9189
attr.ConvAttribute(
@@ -161,9 +159,9 @@ def define_node(
161159
# integer value domain of the next op. Otherwise return float32 output.
162160
if is_quant_node:
163161
# Get scale_factor from input, weight, and output.
164-
input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale
165-
weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale
166-
output_qargs = search_quant_arg_downstream(list(node.users)[0])
162+
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
163+
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
164+
output_qargs = get_quant_arg_downstream(list(node.users)[0])
167165

168166
build_rescale_conv_output(
169167
tosa_graph,

backends/arm/operators/op_exp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2022
QuantArgs,
2123
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2424
)
2525
from serializer.tosa_serializer import TosaOp
2626
from torch.fx import Node
@@ -49,9 +49,9 @@ def define_node(
4949

5050
# Create attribute for 8 bit table lookup.
5151
input_node = node.all_input_nodes[0]
52-
in_quantargs = search_quant_arg_upstream(input_node)
52+
in_quantargs = get_quant_arg_upstream(input_node)
5353
output_node = list(node.users)[0]
54-
out_quantargs = search_quant_arg_downstream(output_node)
54+
out_quantargs = get_quant_arg_downstream(output_node)
5555

5656
table = exp_table_8bit(in_quantargs, out_quantargs)
5757
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_full.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
1717
from executorch.backends.arm.tosa_quant_utils import (
18+
get_quant_arg_downstream,
1819
quantize_value,
19-
search_quant_arg_downstream,
2020
)
2121
from executorch.backends.arm.tosa_utils import tosa_shape
2222
from torch.fx import Node
@@ -42,7 +42,7 @@ def define_node(
4242

4343
value = inputs[1].number
4444
if is_quant_node:
45-
qargs = search_quant_arg_downstream(list(node.users)[0])
45+
qargs = get_quant_arg_downstream(list(node.users)[0])
4646
qvalue = quantize_value(value, qargs)
4747
dtype = ts.DType.INT8
4848
data = np.full(shape, qvalue, dtype=np.int8)

backends/arm/operators/op_hardtanh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616

1717
from executorch.backends.arm.tosa_quant_utils import (
18+
get_quant_arg_upstream,
1819
quantize_value,
19-
search_quant_arg_upstream,
2020
)
2121
from serializer.tosa_serializer import TosaOp
2222

@@ -40,7 +40,7 @@ def define_node(
4040

4141
if is_quant_node:
4242
# Get quant parameters
43-
qargs = search_quant_arg_upstream(node.all_input_nodes[0])
43+
qargs = get_quant_arg_upstream(node.all_input_nodes[0])
4444
# Convert to quantized representation
4545
clamp_min_qs = quantize_value(inputs[1].number, qargs)
4646
clamp_max_qs = quantize_value(inputs[2].number, qargs)

backends/arm/operators/op_log.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2022
QuantArgs,
2123
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2424
)
2525
from serializer.tosa_serializer import TosaOp
2626
from torch.fx import Node
@@ -50,9 +50,9 @@ def define_node(
5050

5151
# Create attribute for 8 bit table lookup.
5252
input_node = node.all_input_nodes[0]
53-
in_quantargs = search_quant_arg_upstream(input_node)
53+
in_quantargs = get_quant_arg_upstream(input_node)
5454
output_node = list(node.users)[0]
55-
out_quantargs = search_quant_arg_downstream(output_node)
55+
out_quantargs = get_quant_arg_downstream(output_node)
5656

5757
table = log_table_8bit(in_quantargs, out_quantargs)
5858
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_max_pool2d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_utils import (
17-
search_quant_arg_downstream,
18-
search_quant_arg_upstream,
17+
get_quant_arg_downstream,
18+
get_quant_arg_upstream,
1919
)
2020

2121
from serializer.tosa_serializer import TosaOp
@@ -57,10 +57,10 @@ def define_node(
5757
output_zp = 0
5858

5959
if is_quant_node:
60-
input_zp = search_quant_arg_upstream(
60+
input_zp = get_quant_arg_upstream(
6161
cast(torch.fx.Node, node.all_input_nodes[0])
6262
).zp
63-
output_zp = search_quant_arg_downstream(list(node.users)[0]).zp
63+
output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
6464

6565
attr = ts.TosaSerializerAttribute()
6666
attr.PoolAttribute(

backends/arm/operators/op_mm.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from executorch.backends.arm.tosa_mapping import TosaArg
1717
from executorch.backends.arm.tosa_quant_utils import (
1818
build_rescale,
19-
search_quant_arg_downstream,
20-
search_quant_arg_upstream,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
2121
)
2222
from executorch.backends.arm.tosa_utils import (
2323
build_reshape,
@@ -58,8 +58,8 @@ def define_node(
5858
# For INT8, we need to get the zero point, otherwise it is 0
5959
input0_zp, input1_zp = 0, 0
6060
if is_quant_node:
61-
input0_zp = search_quant_arg_upstream(input0).zp
62-
input1_zp = search_quant_arg_upstream(input1).zp
61+
input0_zp = get_quant_arg_upstream(input0).zp
62+
input1_zp = get_quant_arg_upstream(input1).zp
6363

6464
mat_mul_result = tosa_graph.addIntermediate(
6565
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
@@ -90,9 +90,9 @@ def define_node(
9090

9191
# As INT8 accumulates into INT32, we need to rescale it back to INT8
9292
if is_quant_node:
93-
input0_q_params = search_quant_arg_upstream(input0)
94-
input1_q_params = search_quant_arg_upstream(input1)
95-
output_q_params = search_quant_arg_downstream(list(node.users)[0])
93+
input0_q_params = get_quant_arg_upstream(input0)
94+
input1_q_params = get_quant_arg_upstream(input1)
95+
output_q_params = get_quant_arg_downstream(list(node.users)[0])
9696

9797
final_output_scale = (
9898
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def define_node(
3737
if is_quant_node:
3838
input_A = inputs[0]
3939
input_B = inputs[1]
40-
input_A_qargs = tqutils.search_quant_arg_upstream(
40+
input_A_qargs = tqutils.get_quant_arg_upstream(
4141
cast(torch.fx.Node, node.args[0])
4242
)
43-
input_B_qargs = tqutils.search_quant_arg_upstream(
43+
input_B_qargs = tqutils.get_quant_arg_upstream(
4444
cast(torch.fx.Node, node.args[1])
4545
)
4646

backends/arm/operators/op_placeholder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import torch.fx
1111
from executorch.backends.arm.tosa_mapping import TosaArg
1212
from executorch.backends.arm.tosa_quant_utils import (
13+
get_quant_arg_upstream,
1314
get_quantized_node_output_dtype,
1415
is_node_quantized,
15-
search_quant_arg_upstream,
1616
)
1717
from executorch.backends.arm.tosa_utils import (
1818
is_bias_node_for_quantized_addmm,
@@ -80,8 +80,8 @@ def process_quantized_bias(
8080
_,
8181
) = consumer_node.all_input_nodes
8282

83-
input_node_scale = search_quant_arg_upstream(input_node).scale
84-
weight_node_scale = search_quant_arg_upstream(weight_node).scale
83+
input_node_scale = get_quant_arg_upstream(input_node).scale
84+
weight_node_scale = get_quant_arg_upstream(weight_node).scale
8585
bias_values_quantized = (
8686
(parameter_values / (input_node_scale * weight_node_scale))
8787
.round()

backends/arm/operators/op_reciprocal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_quant_utils import (
1717
dequantize_value,
18+
get_quant_arg_downstream,
19+
get_quant_arg_upstream,
1820
QuantArgs,
1921
quantize_value,
20-
search_quant_arg_downstream,
21-
search_quant_arg_upstream,
2222
)
2323
from serializer.tosa_serializer import TosaOp
2424

@@ -42,8 +42,8 @@ def define_node(
4242

4343
if is_quant_node:
4444
input = inputs[0]
45-
input_qargs = search_quant_arg_upstream(node.all_input_nodes[0])
46-
output_qargs = search_quant_arg_downstream(list(node.users)[0])
45+
input_qargs = get_quant_arg_upstream(node.all_input_nodes[0])
46+
output_qargs = get_quant_arg_downstream(list(node.users)[0])
4747

4848
div_table = div_table_8bit(input_qargs, output_qargs)
4949

backends/arm/operators/op_relu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
clamp_min_qs = 0
3939
clamp_max_qs = 0
4040
if is_quant_node:
41-
out_qargs = tqutils.search_quant_arg_downstream(list(node.users)[0])
41+
out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0])
4242
clamp_min_qs = tqutils.quantize_value(0, out_qargs)
4343
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs)
4444

backends/arm/operators/op_rsqrt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from executorch.backends.arm.tosa_mapping import TosaArg
1717
from executorch.backends.arm.tosa_quant_utils import (
1818
dequantize_value,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
1921
QuantArgs,
2022
quantize_value,
21-
search_quant_arg_downstream,
22-
search_quant_arg_upstream,
2323
)
2424
from serializer.tosa_serializer import TosaOp
2525

@@ -40,9 +40,9 @@ def define_node(
4040
# Assume quantized input is 8 bit.
4141
# Create attribute for 8 bit table lookup.
4242
input_node = node.all_input_nodes[0]
43-
in_quantargs = search_quant_arg_upstream(input_node)
43+
in_quantargs = get_quant_arg_upstream(input_node)
4444
output_node = list(node.users)[0]
45-
out_quantargs = search_quant_arg_downstream(output_node)
45+
out_quantargs = get_quant_arg_downstream(output_node)
4646
table = rsqrt_table_8bit(in_quantargs, out_quantargs)
4747
table_attr = ts.TosaSerializerAttribute()
4848
table_attr.TableAttribute(table)

backends/arm/operators/op_sigmoid.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2022
QuantArgs,
2123
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2424
)
2525
from serializer.tosa_serializer import TosaOp
2626
from torch.fx import Node
@@ -50,9 +50,9 @@ def define_node(
5050

5151
# Create attribute for 8 bit table lookup.
5252
input_node = node.all_input_nodes[0]
53-
in_quantargs = search_quant_arg_upstream(input_node)
53+
in_quantargs = get_quant_arg_upstream(input_node)
5454
output_node = list(node.users)[0]
55-
out_quantargs = search_quant_arg_downstream(output_node)
55+
out_quantargs = get_quant_arg_downstream(output_node)
5656

5757
table = sigmoid_table_8bit(in_quantargs, out_quantargs)
5858
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_tanh.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2022
QuantArgs,
2123
quantize_value,
22-
search_quant_arg_downstream,
23-
search_quant_arg_upstream,
2424
)
2525
from serializer.tosa_serializer import TosaOp
2626
from torch.fx import Node
@@ -50,9 +50,9 @@ def define_node(
5050

5151
# Create attribute for 8 bit table lookup.
5252
input_node = node.all_input_nodes[0]
53-
in_quantargs = search_quant_arg_upstream(input_node)
53+
in_quantargs = get_quant_arg_upstream(input_node)
5454
output_node = list(node.users)[0]
55-
out_quantargs = search_quant_arg_downstream(output_node)
55+
out_quantargs = get_quant_arg_downstream(output_node)
5656

5757
table = tanh_table_8bit(in_quantargs, out_quantargs)
5858
table_attr = ts.TosaSerializerAttribute()

0 commit comments

Comments
 (0)