Arm backend: Add ComputeConstantOpsAOT pass #9504

Merged: 3 commits, Mar 25, 2025
backends/arm/_passes/arm_pass_manager.py (20 changes: 13 additions & 7 deletions)

@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
+
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))

         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
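Note: in both the _tosa_080_BI and _tosa_080_MI pipelines, the old FuseConstantOpsPass is split in two. ComputeConstantOpsAOT now runs early (UnsqueezeScalarPlaceholdersPass and MatchArgRanksPass move up with it), while FuseConstantArgsPass runs late, after FuseViewCopyTransform. The sketch below illustrates the general ahead-of-time constant-folding idea on a plain torch.fx graph; it is a minimal sketch with a hypothetical helper name (fold_constant_ops_aot), not the actual implementation of the new pass.

import torch
from torch import fx


def fold_constant_ops_aot(gm: fx.GraphModule) -> fx.GraphModule:
    # Evaluate call_function nodes whose inputs are all known constants once
    # at export time, and replace each with a get_attr to a precomputed buffer.
    env: dict[fx.Node, torch.Tensor] = {}
    for node in list(gm.graph.nodes):
        if node.op == "get_attr":
            env[node] = getattr(gm, node.target)
            continue
        if node.op != "call_function":
            continue
        node_args = [a for a in node.args if isinstance(a, fx.Node)]
        if any(a not in env for a in node_args):
            continue  # depends on a non-constant input (e.g. a placeholder)
        if any(isinstance(v, fx.Node) for v in node.kwargs.values()):
            continue  # keep the sketch simple: no Node-valued kwargs
        args = tuple(env[a] if isinstance(a, fx.Node) else a for a in node.args)
        value = node.target(*args, **node.kwargs)  # computed ahead of time
        if not isinstance(value, torch.Tensor):
            continue
        name = f"_folded_{node.name}"
        gm.register_buffer(name, value)
        with gm.graph.inserting_before(node):
            const = gm.graph.get_attr(name)
        node.replace_all_uses_with(const)
        gm.graph.erase_node(node)
        env[const] = value
    gm.graph.lint()
    gm.recompile()
    return gm


# Example: the full/mul chain takes no graph inputs, so it folds away and
# only the add involving the placeholder x remains.
def f(x):
    return x + torch.full((2,), 3.0) * 2.0

gm = fold_constant_ops_aot(fx.symbolic_trace(f))
print(gm.code)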
backends/arm/_passes/cast_int64_pass.py (55 changes: 28 additions & 27 deletions)

@@ -8,7 +8,6 @@
 import logging

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import is_param_node
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer
@@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram):
         super(CastInt64ToInt32Pass, self).__init__()
         self.exported_program = exported_program

+    def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+        if torch.min(tensor) < torch.iinfo(torch.int32).min:
+            raise RuntimeError(
+                f"Node {node.name} has value < {torch.iinfo(torch.int32).min}"
+            )
+        if torch.max(tensor) > torch.iinfo(torch.int32).max:
+            raise RuntimeError(
+                f"Node {node.name} has value > {torch.iinfo(torch.int32).max}"
+            )
+
     def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             fake_tensor = node.meta["val"]
-            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                if node.meta["val"].dtype == torch.int64 and is_param_node(
-                    self.exported_program, node
-                ):
-                    if is_buffer(self.exported_program, node):
-                        node.meta["val"] = node.meta["val"].to(torch.int32)
-                        buffer_name = (
-                            self.exported_program.graph_signature.inputs_to_buffers[
-                                node.name
-                            ]
-                        )
-                        buffer = self.exported_program.state_dict[node.name]
-                        logger.warning(
-                            f"Casting buffer {node.name} from torch.int64 to torch.int32"
-                            f" defined in {node.meta['stack_trace']}"
-                        )
-                        if torch.min(buffer) < torch.iinfo(torch.int32).min:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
-                            )
-                        if torch.max(buffer) > torch.iinfo(torch.int32).max:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
-                            )
-                        buffer_int32 = buffer.to(torch.int32)
-                        self.exported_program.state_dict[buffer_name] = buffer_int32
+            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                continue
+            if fake_tensor.dtype != torch.int64:
+                continue
+            if is_buffer(self.exported_program, node):
+                node.meta["val"] = fake_tensor.to(torch.int32)
+                buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                    node.name
+                ]
+                buffer = self.exported_program.state_dict[node.name]
+                self._assert_within_int32(buffer, node)
+                logger.warning(
+                    f"Casting buffer {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                buffer_int32 = buffer.to(torch.int32)
+                self.exported_program.state_dict[buffer_name] = buffer_int32
+                continue

     def call(self, graph_module: torch.fx.GraphModule):
         self._to_int32(graph_module)
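Note: the rewrite flattens the deeply nested conditionals into early continues and factors the int64 range check out into _assert_within_int32; the warning also now tolerates a missing stack_trace via node.meta.get. A standalone sketch of that safety check, assuming only torch (assert_within_int32 here is a free-function stand-in, not the method itself):

import torch


def assert_within_int32(tensor: torch.Tensor, name: str) -> None:
    # Casting int64 -> int32 is only lossless if every element fits in int32.
    info = torch.iinfo(torch.int32)
    if torch.min(tensor) < info.min or torch.max(tensor) > info.max:
        raise RuntimeError(f"{name} has values outside the int32 range")


buf = torch.tensor([0, 2**31 - 1], dtype=torch.int64)
assert_within_int32(buf, "buf")      # within range: no error
print(buf.to(torch.int32).dtype)     # torch.int32

# torch.tensor([2**31], dtype=torch.int64) would fail the check above,
# since 2**31 overflows int32.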
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py (17 changes: 2 additions & 15 deletions)

@@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult:

 class QuantizeOperatorArguments(ExportPass):
     """
-    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    This pass makes sure that the arguments to clamp.default are quantized correctly.
     More specifically, this pass:
-    - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
-      the folding pass above to make sure that the retraced output of the full.default op is
-      the right dtype.
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """

@@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.target not in {
                 exir_ops.edge.aten.clamp.default,
-                exir_ops.edge.aten.full.default,
             }:
                 continue
@@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult:

             qargs = QuantArgs.from_operator(user.target, user.args)

-            if n.target == exir_ops.edge.aten.full.default:
-                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                    # replace the node arg with a quantized dito and also set dtype
-                    # to get the right output according to the Edge IR specification:
-                    # exir/dialects/edge/edge.yaml:3596
-                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                    n.update_arg(1, quantized_full_value)
-                    n.update_kwarg("dtype", qargs.dtype)
-                    modified = True
-            elif n.target == exir_ops.edge.aten.clamp.default:
+            if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
                 min_val = n.args[1]
                 max_val = None if len(n.args) <= 2 else n.args[2]
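Note: the full.default handling is dropped from QuantizeOperatorArguments, presumably because full outputs are now materialized ahead of time by the new ComputeConstantOpsAOT pass, leaving only the clamp bounds to quantize here. For context, quantizing a bound means mapping it into the integer domain using the consuming op's scale and zero point; a minimal sketch with a hypothetical quantize_bound helper (not the QuantArgs API):

import torch


def quantize_bound(value: float, scale: float, zero_point: int,
                   dtype: torch.dtype = torch.int8) -> int:
    # Affine quantization: q = round(x / scale) + zero_point, clamped to the
    # integer range of the target dtype.
    info = torch.iinfo(dtype)
    q = round(value / scale) + zero_point
    return max(info.min, min(info.max, q))


# Quantize clamp bounds min=0.0, max=6.0 for scale=0.05, zero_point=-128:
print(quantize_bound(0.0, 0.05, -128))  # -128
print(quantize_bound(6.0, 0.05, -128))  # -8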