Skip to content

Commit eaad7ff

Browse files
Revert "Remove unused functions for quantization handling" (#7724)
Revert "Remove unused functions for quantization handling (#7700)" This reverts commit ffc2020.
1 parent ffc2020 commit eaad7ff

File tree

10 files changed

+399
-21
lines changed

10 files changed

+399
-21
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -15,7 +15,7 @@
1515
get_node_arg,
1616
insert_q_dq_pair,
1717
)
18-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
18+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
1919
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
2020
from executorch.exir.dialects._ops import ops as exir_ops
2121
from executorch.exir.pass_base import ExportPass, PassResult
@@ -43,6 +43,9 @@ def _transpose_impl(*args, **kwargs):
4343
return args[0]
4444

4545

46+
register_passable_op(torch.ops.passthrough_to_tosa._transpose)
47+
48+
4649
class AnnotateChannelsLastDimOrder(ExportPass):
4750
"""
4851
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
op_bmm,
1414
op_cat,
1515
op_conv2d,
16+
op_dequant,
1617
op_exp,
1718
op_full,
1819
op_get_item,
@@ -23,6 +24,7 @@
2324
op_min,
2425
op_mul,
2526
op_permute,
27+
op_quant,
2628
op_reciprocal,
2729
op_relu,
2830
op_repeat,

backends/arm/operators/op_dequant.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from serializer.tosa_serializer import TosaOp
17+
18+
19+
@register_node_visitor
20+
class DequantVisitor(NodeVisitor):
21+
target = "quantized_decomposed.dequantize_per_tensor.default"
22+
23+
def __init__(self, *args):
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
tosa_graph: ts.TosaSerializer,
30+
inputs: List[TosaArg],
31+
output: TosaArg,
32+
) -> None:
33+
item_name = inputs[0].name
34+
## Simply add an identityOp
35+
tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name])

backends/arm/operators/op_hardtanh.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -19,6 +19,7 @@
1919
)
2020
from executorch.backends.arm.tosa_mapping import TosaArg
2121

22+
from executorch.backends.arm.tosa_quant_utils import quantize_value
2223
from serializer.tosa_serializer import TosaOp
2324

2425

@@ -43,8 +44,8 @@ def define_node(
4344
input_qparams = get_input_qparams(node) # pyre-ignore[16]
4445
qargs = input_qparams[0]
4546
# Convert to quantized representation
46-
clamp_min_qs = qargs.quantize_value(inputs[1].number).item()
47-
clamp_max_qs = qargs.quantize_value(inputs[2].number).item()
47+
clamp_min_qs = quantize_value(inputs[1].number, qargs)
48+
clamp_max_qs = quantize_value(inputs[2].number, qargs)
4849
# Set fp values to 0.0 since they are not used
4950
clamp_min_fp = 0.0
5051
clamp_max_fp = 0.0

backends/arm/operators/op_quant.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from serializer.tosa_serializer import TosaOp
17+
18+
19+
@register_node_visitor
20+
class QuantVisitor(NodeVisitor):
21+
target = "quantized_decomposed.quantize_per_tensor.default"
22+
23+
def __init__(self, *args):
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
tosa_graph: ts.TosaSerializer,
30+
inputs: List[TosaArg],
31+
output: TosaArg,
32+
) -> None:
33+
item_name = inputs[0].name
34+
## Simply add an identityOp
35+
tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name])

backends/arm/operators/op_relu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
77

8+
import executorch.backends.arm.tosa_quant_utils as tqutils
89
import serializer.tosa_serializer as ts
910
import torch.fx
1011

@@ -42,8 +43,9 @@ def define_node(
4243
clamp_max_qs = 0
4344
if inputs[0].dtype == ts.DType.INT8:
4445
out_qargs = get_output_qparams(node) # pyre-ignore[16]
45-
clamp_min_qs = out_qargs[0].quantize_value(0).item()
46-
clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item()
46+
clamp_min_qs = tqutils.quantize_value(0, out_qargs[0])
47+
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0])
48+
4749
else:
4850
clamp_min_fp = 0
4951
clamp_max_fp = float("inf")

backends/arm/process_node.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
import torch
1313
import torch.fx
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
15-
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
16+
from executorch.backends.arm.tosa_quant_utils import (
17+
dq_op,
18+
get_quantized_node_output_dtype,
19+
is_node_quantized,
20+
)
1621
from executorch.backends.arm.tosa_specification import TosaSpecification
1722
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
1823
from torch.export.exported_program import ExportedProgram
@@ -30,8 +35,15 @@ def process_call_function(
3035
# Convert output (this node itself)
3136
output = TosaArg(node)
3237

38+
is_dq_node = node.target == dq_op
39+
if is_dq_node:
40+
output_dtype = ts.DType.INT8
41+
else:
42+
output_dtype = output.dtype
3343
tosa_graph.currRegion.currBasicBlock.addTensor(
34-
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
44+
output.name,
45+
tosa_shape(output.shape, output.dim_order),
46+
output_dtype,
3547
)
3648

3749
# Visiting each Node
@@ -67,7 +79,11 @@ def process_inputs(
6779
tensor = ts.TosaSerializerTensor(
6880
inputs[0].name,
6981
tosa_shape(input_shape, input_dim_order),
70-
inputs[0].dtype,
82+
(
83+
map_dtype(get_quantized_node_output_dtype(node))
84+
if is_node_quantized(node)
85+
else inputs[0].dtype
86+
),
7187
data=None,
7288
placeholderFilename=inputs[0].name + ".npy",
7389
)

0 commit comments

Comments
 (0)