Revert "Add full operator to fold dq/q handling" #7351

Merged · 10 commits · Dec 18, 2024
18 changes: 0 additions & 18 deletions backends/arm/_passes/arm_pass_manager.py

@@ -29,10 +29,6 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
-from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
-    FoldAndAnnotateQParamsPass,
-    QuantizeFullArgument,
-)
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )

@@ -54,7 +50,6 @@
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager

@@ -85,19 +80,6 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
-        self.add_pass(QuantizeFullArgument())
-        self.add_pass(
-            FoldAndAnnotateQParamsPass(
-                [
-                    exir_ops.edge.aten.minimum.default,
-                    exir_ops.edge.aten.maximum.default,
-                    exir_ops.edge.aten.add.Tensor,
-                    exir_ops.edge.aten.avg_pool2d.default,
-                    exir_ops.edge.aten.convolution.default,
-                    exir_ops.edge.aten.full.default,
-                ]
-            )
-        )
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
181 changes: 0 additions & 181 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

This file was deleted.
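For context, the deleted pass folded explicit dequantize/quantize (dq/q) ops into quantization-parameter annotations on the surrounding node's meta. Below is a toy sketch of that general idea on a hypothetical mini-IR; it is not the deleted pass's actual FX implementation, just the folding concept:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    op: str                                      # e.g. "placeholder", "dq", "add", "q"
    inputs: list["Node"] = field(default_factory=list)
    qparams: Optional[tuple] = None              # (scale, zero_point) on q/dq nodes
    meta: dict = field(default_factory=dict)

def fold_qdq(node: Node) -> Node:
    """Fold dq producers and a q consumer into meta annotations on `node`."""
    if node.op == "q":
        # Drop the quantize node; stamp its qparams onto the producer.
        inner = fold_qdq(node.inputs[0])
        inner.meta["output_qparams"] = node.qparams
        return inner
    # Record per-input qparams, then bypass the dequantize nodes.
    node.meta["input_qparams"] = [
        n.qparams if n.op == "dq" else None for n in node.inputs
    ]
    node.inputs = [
        fold_qdq(n.inputs[0] if n.op == "dq" else n) for n in node.inputs
    ]
    return node

# dq(x) + dq(y) -> q collapses to a single add annotated with its qparams.
x, y = Node("placeholder"), Node("placeholder")
out = Node("q", [Node("add", [Node("dq", [x], (0.1, 0)), Node("dq", [y], (0.2, 3))])], (0.05, -1))
add = fold_qdq(out)
assert add.op == "add" and add.meta["output_qparams"] == (0.05, -1)
assert add.meta["input_qparams"] == [(0.1, 0), (0.2, 3)]
```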

2 changes: 0 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py

@@ -94,8 +94,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
-            exir_ops.edge.aten.minimum.default,
-            exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,
2 changes: 0 additions & 2 deletions backends/arm/operators/__init__.py

@@ -19,9 +19,7 @@
     op_get_item,
     op_hardtanh,
     op_log,
-    op_max,
     op_max_pool2d,
-    op_min,
     op_mm,
     op_mul,
     op_permute,
51 changes: 25 additions & 26 deletions backends/arm/operators/op_add.py

@@ -11,6 +11,7 @@
 import executorch.backends.arm.tosa_utils as tutils

 import serializer.tosa_serializer as ts
+import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,

@@ -40,27 +41,33 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        # Specification (0.80) states that input and output types
-        # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
-        # Handle int8 (quantized) and int32
-        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
-
-        if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node
+        input_nodes = tutils.get_two_inputs(node)
+
+        if not is_quant_node and not all(
+            tensor.meta["val"].dtype in (torch.int8, torch.int32)
+            for tensor in input_nodes
+        ):
+            raise RuntimeError(
+                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
             )
-        else:
-            # input[0].dtype == ts.DType.INT32
-            # Non quantized input, natively support by TOSA.ADD
-            rescaled_inputs = inputs
-
-        if output.dtype == ts.DType.INT8:
+
+        needs_rescale = not (
+            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
+            and node.meta["val"].dtype == torch.int32
+        )
+
+        if needs_rescale:
+            # Rescale inputs to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                input_nodes, tosa_graph
+            )
+
             # Prepare add output tensor
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
             add_output = output
+            rescaled_inputs = inputs

         # Do the INT32 Add
         tosa_graph.addOperator(

@@ -73,10 +80,10 @@
             None,
         )

-        if output.dtype == ts.DType.INT8:
+        if needs_rescale:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
+            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)

@@ -98,19 +105,11 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        # Specification (0.80) states that input and output types
-        # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
-
-        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+        if is_quant_node:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
-            # FP32 Add lowering
-            assert inputs[0].dtype == ts.DType.FP32
-            assert output.dtype == ts.DType.FP32
-
+            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],
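For readers outside the Arm backend: the restored BI lowering follows the usual integer-add pattern, rescale both quantized operands into a shared int32 domain, emit the INT32 ADD, then rescale the result back to int8. Below is a toy numpy sketch of that arithmetic (zero-points omitted for brevity; the real `tqutils` helpers do this symbolically in the TOSA graph rather than on concrete arrays):

```python
import numpy as np

def quantized_add(xq, x_scale, yq, y_scale, out_scale):
    """Toy int8 add via int32, mirroring rescale-to-int32 + ADD + rescale-back."""
    # Rescale inputs to 32 bit, expressing both in a common (larger) scale.
    common = max(x_scale, y_scale)
    x32 = np.round(xq.astype(np.int32) * (x_scale / common)).astype(np.int32)
    y32 = np.round(yq.astype(np.int32) * (y_scale / common)).astype(np.int32)
    acc = x32 + y32  # the INT32 ADD emitted into the graph
    # Scale output back to 8 bit.
    out = np.round(acc * (common / out_scale))
    return np.clip(out, -128, 127).astype(np.int8)

xq = np.array([10, -20], dtype=np.int8)   # represents [1.0, -2.0] at scale 0.1
yq = np.array([30, 40], dtype=np.int8)    # represents [3.0, 4.0] at scale 0.1
print(quantized_add(xq, 0.1, yq, 0.1, 0.2))  # [20, 10] == [4.0, 2.0] at scale 0.2
```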