Skip to content

Commit d6df0cb

Browse files
authored
Merge branch 'main' into gh/helunwencser/74/orig
2 parents 8e2d359 + 671f9c5 commit d6df0cb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2170
-203
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@
6464
[submodule "third-party/pybind11"]
6565
path = third-party/pybind11
6666
url = https://github.com/pybind/pybind11.git
67+
[submodule "third-party/ao"]
68+
path = third-party/ao
69+
url = https://github.com/pytorch/ao.git

backends/arm/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ python_library(
7070
],
7171
)
7272

73+
python_library(
74+
name = "tosa_specification",
75+
srcs = [
76+
"tosa_specification.py",
77+
],
78+
typing = True,
79+
deps = [
80+
"fbsource//third-party/pypi/packaging:packaging",
81+
"//executorch/exir/backend:compile_spec_schema",
82+
],
83+
)
84+
7385
python_library(
7486
name = "tosa_utils",
7587
srcs = [

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_first_fake_tensor,
1515
insert_q_dq_pair,
1616
)
17-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
17+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
1818
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass, PassResult
@@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
4242
return args[0]
4343

4444

45+
register_passable_op(torch.ops.passthrough_to_tosa._transpose)
46+
47+
4548
class AnnotateChannelsLastDimOrder(ExportPass):
4649
"""
4750
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order

backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
import torch
1010
import torch.fx
11-
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
12-
13-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
1412
from executorch.exir.dialects._ops import ops as exir_ops
1513
from executorch.exir.pass_base import ExportPass, PassResult
1614

@@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
2826
sum(dims, keep_dim = False)
2927
After pass:
3028
sum(dims, keep_dim = True)
31-
(q)
32-
(dq)
3329
squeeze(dim = dims)
3430
"""
3531

@@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
4541
continue
4642

4743
dim_list = cast(list[int], sum_node.args[1])
48-
quantized = is_quant_node(sum_node)
49-
if quantized:
50-
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51-
qparams = qparams + (torch.int8,)
52-
else:
53-
qparams = None
5444

5545
# Add keep_dim = True arg to sum node.
5646
sum_node.args = sum_node.args[0:2] + (True,)
@@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
6151
)
6252
sum_node.replace_all_uses_with(squeeze_node)
6353
squeeze_node.args = (sum_node, dim_list)
64-
if quantized:
65-
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
6654
graph_module.graph.eliminate_dead_code()
6755
graph_module.recompile()
6856
graph_module = super().call(graph_module).graph_module

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import cast, Optional
1010

1111
import torch.fx
12-
from executorch.backends.arm.tosa_quant_utils import is_quant_node
12+
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch._ops import OpOverload
@@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
113113
slice_node = graph.create_node(
114114
"call_function", self.slice_op, (last_node,) + args
115115
)
116-
if is_quant_node(last_node):
116+
if is_node_quantized(last_node):
117117
q_params = last_node.args[1:]
118118
dq_node = insert_q_dq_pair(
119119
graph_module.graph, slice_node, q_params

backends/arm/operators/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ python_library(
77
typing = True,
88
deps = [
99
"//executorch/backends/arm:tosa_mapping",
10+
"//executorch/backends/arm:tosa_specification",
1011
],
1112
)
1213

backends/arm/operators/op_bmm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
21+
)
1822
from executorch.backends.arm.tosa_utils import get_two_inputs
1923
from serializer.tosa_serializer import TosaOp
2024

@@ -42,8 +46,10 @@ def define_node(
4246
# For INT8, we need to get the zero points and add an intermediate tensor
4347
# for a later rescale.
4448
if is_quant_node:
45-
input0_zp = get_quant_node_args(input0).zp
46-
input1_zp = get_quant_node_args(input1).zp
49+
input0_q_params = get_quant_arg_upstream(input0)
50+
input1_q_params = get_quant_arg_upstream(input1)
51+
input0_zp = input0_q_params.zp
52+
input1_zp = input1_q_params.zp
4753
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
4854
bmm_output_name = bmm_result.name
4955
else:
@@ -63,9 +69,7 @@ def define_node(
6369

6470
# As INT8 accumulates into INT32, we need to rescale it back to INT8
6571
if is_quant_node:
66-
input0_q_params = get_quant_node_args(input0)
67-
input1_q_params = get_quant_node_args(input1)
68-
output_q_params = get_quant_node_args(list(node.users)[0])
72+
output_q_params = get_quant_arg_downstream(list(node.users)[0])
6973

7074
final_output_scale = (
7175
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_conv2d.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import cast, List
7+
from typing import List
88

99
import serializer.tosa_serializer as ts
1010
import torch
@@ -15,9 +15,10 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_quant_utils import (
1717
build_rescale_conv_output,
18-
get_quant_node_args,
18+
get_quant_arg_downstream,
19+
get_quant_arg_upstream,
1920
)
20-
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
21+
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
2122

2223
from serializer.tosa_serializer import TosaOp
2324

@@ -82,7 +83,7 @@ def define_node(
8283
)
8384

8485
input_zp = (
85-
get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
86+
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
8687
)
8788

8889
attr.ConvAttribute(
@@ -158,9 +159,10 @@ def define_node(
158159
# integer value domain of the next op. Otherwise return float32 output.
159160
if is_quant_node:
160161
# Get scale_factor from input, weight, and output.
161-
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
162-
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
163-
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
162+
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
163+
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
164+
output_qargs = get_quant_arg_downstream(list(node.users)[0])
165+
164166
build_rescale_conv_output(
165167
tosa_graph,
166168
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
@@ -169,6 +171,6 @@ def define_node(
169171
actual_out_type,
170172
input_scale,
171173
weight_scale,
172-
output_scale,
173-
output_zp,
174+
output_qargs.scale,
175+
output_qargs.zp,
174176
)

backends/arm/operators/op_exp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20-
get_quant_node_args,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2122
QuantArgs,
2223
quantize_value,
2324
)
@@ -48,9 +49,9 @@ def define_node(
4849

4950
# Create attribute for 8 bit table lookup.
5051
input_node = node.all_input_nodes[0]
51-
in_quantargs = get_quant_node_args(input_node)
52+
in_quantargs = get_quant_arg_upstream(input_node)
5253
output_node = list(node.users)[0]
53-
out_quantargs = get_quant_node_args(output_node)
54+
out_quantargs = get_quant_arg_downstream(output_node)
5455

5556
table = exp_table_8bit(in_quantargs, out_quantargs)
5657
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_full.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
get_quant_arg_downstream,
19+
quantize_value,
20+
)
1821
from executorch.backends.arm.tosa_utils import tosa_shape
1922
from torch.fx import Node
2023

@@ -39,10 +42,8 @@ def define_node(
3942

4043
value = inputs[1].number
4144
if is_quant_node:
42-
qargs = get_quant_node_args(list(node.users)[0])
43-
qvalue = np.clip(
44-
np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
45-
)
45+
qargs = get_quant_arg_downstream(list(node.users)[0])
46+
qvalue = quantize_value(value, qargs)
4647
dtype = ts.DType.INT8
4748
data = np.full(shape, qvalue, dtype=np.int8)
4849
else:

backends/arm/operators/op_hardtanh.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616

17-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
get_quant_arg_upstream,
19+
quantize_value,
20+
)
1821
from serializer.tosa_serializer import TosaOp
1922

2023

@@ -37,12 +40,10 @@ def define_node(
3740

3841
if is_quant_node:
3942
# Get quant parameters
40-
scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
43+
qargs = get_quant_arg_upstream(node.all_input_nodes[0])
4144
# Convert to quantized representation
42-
clamp_min_qs = round((inputs[1].number / scale) + zp)
43-
clamp_min_qs = max(clamp_min_qs, qmin)
44-
clamp_max_qs = round((inputs[2].number / scale) + zp)
45-
clamp_max_qs = min(clamp_max_qs, qmax)
45+
clamp_min_qs = quantize_value(inputs[1].number, qargs)
46+
clamp_max_qs = quantize_value(inputs[2].number, qargs)
4647
# Set fp values to 0.0 since they are not used
4748
clamp_min_fp = 0.0
4849
clamp_max_fp = 0.0

backends/arm/operators/op_log.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20-
get_quant_node_args,
20+
get_quant_arg_downstream,
21+
get_quant_arg_upstream,
2122
QuantArgs,
2223
quantize_value,
2324
)
@@ -49,9 +50,9 @@ def define_node(
4950

5051
# Create attribute for 8 bit table lookup.
5152
input_node = node.all_input_nodes[0]
52-
in_quantargs = get_quant_node_args(input_node)
53+
in_quantargs = get_quant_arg_upstream(input_node)
5354
output_node = list(node.users)[0]
54-
out_quantargs = get_quant_node_args(output_node)
55+
out_quantargs = get_quant_arg_downstream(output_node)
5556

5657
table = log_table_8bit(in_quantargs, out_quantargs)
5758
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_max_pool2d.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
register_node_visitor,
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_utils import get_quant_node_args
16+
from executorch.backends.arm.tosa_utils import (
17+
get_quant_arg_downstream,
18+
get_quant_arg_upstream,
19+
)
1720

1821
from serializer.tosa_serializer import TosaOp
1922

@@ -54,8 +57,8 @@ def define_node(
5457
output_zp = 0
5558

5659
if is_quant_node:
57-
input_zp = get_quant_node_args(node.all_input_nodes[0]).zp
58-
output_zp = get_quant_node_args(list(node.users)[0]).zp
60+
input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp
61+
output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
5962

6063
attr = ts.TosaSerializerAttribute()
6164
attr.PoolAttribute(

backends/arm/operators/op_mm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
21+
)
1822
from executorch.backends.arm.tosa_utils import (
1923
build_reshape,
2024
expand_dims,
@@ -54,8 +58,8 @@ def define_node(
5458
# For INT8, we need to get the zero point, otherwise it is 0
5559
input0_zp, input1_zp = 0, 0
5660
if is_quant_node:
57-
input0_zp = get_quant_node_args(input0).zp
58-
input1_zp = get_quant_node_args(input1).zp
61+
input0_zp = get_quant_arg_upstream(input0).zp
62+
input1_zp = get_quant_arg_upstream(input1).zp
5963

6064
mat_mul_result = tosa_graph.addIntermediate(
6165
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
@@ -86,9 +90,9 @@ def define_node(
8690

8791
# As INT8 accumulates into INT32, we need to rescale it back to INT8
8892
if is_quant_node:
89-
input0_q_params = get_quant_node_args(input0)
90-
input1_q_params = get_quant_node_args(input1)
91-
output_q_params = get_quant_node_args(list(node.users)[0])
93+
input0_q_params = get_quant_arg_upstream(input0)
94+
input1_q_params = get_quant_arg_upstream(input1)
95+
output_q_params = get_quant_arg_downstream(list(node.users)[0])
9296

9397
final_output_scale = (
9498
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def define_node(
3737
if is_quant_node:
3838
input_A = inputs[0]
3939
input_B = inputs[1]
40-
input_A_qargs = tqutils.get_quant_node_args(
40+
input_A_qargs = tqutils.get_quant_arg_upstream(
4141
cast(torch.fx.Node, node.args[0])
4242
)
43-
input_B_qargs = tqutils.get_quant_node_args(
43+
input_B_qargs = tqutils.get_quant_arg_upstream(
4444
cast(torch.fx.Node, node.args[1])
4545
)
4646

0 commit comments

Comments
 (0)