Skip to content

Revert "Remove unused functions for quantization handling" #7724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -15,7 +15,7 @@
get_node_arg,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -43,6 +43,9 @@ def _transpose_impl(*args, **kwargs):
return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
op_bmm,
op_cat,
op_conv2d,
op_dequant,
op_exp,
op_full,
op_get_item,
Expand All @@ -23,6 +24,7 @@
op_min,
op_mul,
op_permute,
op_quant,
op_reciprocal,
op_relu,
op_repeat,
Expand Down
35 changes: 35 additions & 0 deletions backends/arm/operators/op_dequant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DequantVisitor(NodeVisitor):
    """Visitor for ``quantized_decomposed.dequantize_per_tensor.default`` nodes.

    The node is serialized as a single TOSA IDENTITY operator whose input is
    the node's first argument and whose output is the node's output tensor.
    """

    target = "quantized_decomposed.dequantize_per_tensor.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add an IDENTITY operator from ``inputs[0]`` to ``output``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the new operator.
            inputs: Node arguments; only the name of the first one is used.
            output: Output argument; only its name is used.
        """
        # Only the first argument is forwarded; the remaining arguments are
        # not referenced here.
        source_names = [inputs[0].name]
        tosa_graph.addOperator(TosaOp.Op().IDENTITY, source_names, [output.name])
7 changes: 4 additions & 3 deletions backends/arm/operators/op_hardtanh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2025 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -19,6 +19,7 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import quantize_value
from serializer.tosa_serializer import TosaOp


Expand All @@ -43,8 +44,8 @@ def define_node(
input_qparams = get_input_qparams(node) # pyre-ignore[16]
qargs = input_qparams[0]
# Convert to quantized representation
clamp_min_qs = qargs.quantize_value(inputs[1].number).item()
clamp_max_qs = qargs.quantize_value(inputs[2].number).item()
clamp_min_qs = quantize_value(inputs[1].number, qargs)
clamp_max_qs = quantize_value(inputs[2].number, qargs)
# Set fp values to 0.0 since they are not used
clamp_min_fp = 0.0
clamp_max_fp = 0.0
Expand Down
35 changes: 35 additions & 0 deletions backends/arm/operators/op_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class QuantVisitor(NodeVisitor):
    """Visitor for ``quantized_decomposed.quantize_per_tensor.default`` nodes.

    The node is serialized as a single TOSA IDENTITY operator whose input is
    the node's first argument and whose output is the node's output tensor.
    """

    target = "quantized_decomposed.quantize_per_tensor.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add an IDENTITY operator from ``inputs[0]`` to ``output``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the new operator.
            inputs: Node arguments; only the name of the first one is used.
            output: Output argument; only its name is used.
        """
        # Only the first argument is forwarded; the remaining arguments are
        # not referenced here.
        source_names = [inputs[0].name]
        tosa_graph.addOperator(TosaOp.Op().IDENTITY, source_names, [output.name])
8 changes: 5 additions & 3 deletions backends/arm/operators/op_relu.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
import torch.fx

Expand Down Expand Up @@ -42,8 +43,9 @@ def define_node(
clamp_max_qs = 0
if inputs[0].dtype == ts.DType.INT8:
out_qargs = get_output_qparams(node) # pyre-ignore[16]
clamp_min_qs = out_qargs[0].quantize_value(0).item()
clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item()
clamp_min_qs = tqutils.quantize_value(0, out_qargs[0])
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0])

else:
clamp_min_fp = 0
clamp_max_fp = float("inf")
Expand Down
22 changes: 19 additions & 3 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
dq_op,
get_quantized_node_output_dtype,
is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from torch.export.exported_program import ExportedProgram
Expand All @@ -30,8 +35,15 @@ def process_call_function(
# Convert output (this node itself)
output = TosaArg(node)

is_dq_node = node.target == dq_op
if is_dq_node:
output_dtype = ts.DType.INT8
else:
output_dtype = output.dtype
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
output.name,
tosa_shape(output.shape, output.dim_order),
output_dtype,
)

# Visiting each Node
Expand Down Expand Up @@ -67,7 +79,11 @@ def process_inputs(
tensor = ts.TosaSerializerTensor(
inputs[0].name,
tosa_shape(input_shape, input_dim_order),
inputs[0].dtype,
(
map_dtype(get_quantized_node_output_dtype(node))
if is_node_quantized(node)
else inputs[0].dtype
),
data=None,
placeholderFilename=inputs[0].name + ".npy",
)
Expand Down
Loading
Loading