
Commit 9467d4f

Arm backend: Add ComputeConstantOpsAOT pass
Operators that output tensors based on constant args are pre-computed and added as buffers.
- The pass currently supports full, arange, linspace, and eye.
- Remove some logic for full that is now handled by the pass.
- Rename FuseConstantOpsPass to FuseConstantArgsPass and make minor improvements.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: I744e2583a9ed011e350cfaa43410902bd9e54292
1 parent c91a6c0 commit 9467d4f

12 files changed (+322 / -129 lines)
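For intuition, here is a minimal sketch of the idea behind ComputeConstantOpsAOT, written against plain torch.fx. The helper name precompute_constant_ops and the buffer-naming scheme are illustrative assumptions, not the pass's actual implementation (which lives in backends/arm/_passes/fuse_constant_ops_pass.py and also updates the ExportedProgram's graph signature):

```python
import torch
from torch import fx

# Ops whose output depends only on their (constant) arguments.
# Illustrative set, mirroring the ops the commit message lists.
_CONSTANT_OPS = {
    torch.ops.aten.full.default,
    torch.ops.aten.arange.default,
    torch.ops.aten.linspace.default,
    torch.ops.aten.eye.default,
}


def precompute_constant_ops(gm: fx.GraphModule) -> fx.GraphModule:
    """Run constant-argument ops eagerly and replace them with buffers."""
    for i, node in enumerate(list(gm.graph.nodes)):
        if node.op != "call_function" or node.target not in _CONSTANT_OPS:
            continue
        # Every argument is a plain Python constant, so the op can be
        # evaluated ahead of time.
        data = node.target(*node.args, **node.kwargs)
        name = f"_precomputed_{i}"  # hypothetical naming scheme
        gm.register_buffer(name, data)
        # Read the buffer instead of recomputing the op at runtime.
        with gm.graph.inserting_before(node):
            buffer_node = gm.graph.get_attr(name)
        node.replace_all_uses_with(buffer_node)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

Pre-computing these ops trades a slightly larger serialized payload for never materializing the tensors at runtime.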

backends/arm/_passes/arm_pass_manager.py

Lines changed: 13 additions & 7 deletions
@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())

backends/arm/_passes/cast_int64_pass.py

Lines changed: 28 additions & 27 deletions
@@ -8,7 +8,6 @@
 import logging

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import is_param_node
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer

@@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram):
         super(CastInt64ToInt32Pass, self).__init__()
         self.exported_program = exported_program

+    def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+        if torch.min(tensor) < torch.iinfo(torch.int32).min:
+            raise RuntimeError(
+                f"Node {node.name} has value < {torch.iinfo(torch.int32).min}"
+            )
+        if torch.max(tensor) > torch.iinfo(torch.int32).max:
+            raise RuntimeError(
+                f"Node {node.name} has value > {torch.iinfo(torch.int32).max}"
+            )
+
     def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             fake_tensor = node.meta["val"]
-            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                if node.meta["val"].dtype == torch.int64 and is_param_node(
-                    self.exported_program, node
-                ):
-                    if is_buffer(self.exported_program, node):
-                        node.meta["val"] = node.meta["val"].to(torch.int32)
-                        buffer_name = (
-                            self.exported_program.graph_signature.inputs_to_buffers[
-                                node.name
-                            ]
-                        )
-                        buffer = self.exported_program.state_dict[node.name]
-                        logger.warning(
-                            f"Casting buffer {node.name} from torch.int64 to torch.int32"
-                            f" defined in {node.meta['stack_trace']}"
-                        )
-                        if torch.min(buffer) < torch.iinfo(torch.int32).min:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
-                            )
-                        if torch.max(buffer) > torch.iinfo(torch.int32).max:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
-                            )
-                        buffer_int32 = buffer.to(torch.int32)
-                        self.exported_program.state_dict[buffer_name] = buffer_int32
+            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                continue
+            if fake_tensor.dtype != torch.int64:
+                continue
+            if is_buffer(self.exported_program, node):
+                node.meta["val"] = fake_tensor.to(torch.int32)
+                buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                    node.name
+                ]
+                buffer = self.exported_program.state_dict[node.name]
+                self._assert_within_int32(buffer, node)
+                logger.warning(
+                    f"Casting buffer {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                buffer_int32 = buffer.to(torch.int32)
+                self.exported_program.state_dict[buffer_name] = buffer_int32
+                continue

     def call(self, graph_module: torch.fx.GraphModule):
         self._to_int32(graph_module)
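The refactored pass above replaces the nested conditionals with early continue guards and factors the range check into _assert_within_int32. A quick plain-PyTorch illustration of the bounds it enforces (nothing backend-specific here):

```python
import torch

int32_info = torch.iinfo(torch.int32)
print(int32_info.min, int32_info.max)  # -2147483648 2147483647

# A buffer within range survives the int64 -> int32 round trip unchanged...
ok = torch.tensor([0, 42], dtype=torch.int64)
assert torch.equal(ok.to(torch.int32).to(torch.int64), ok)

# ...while one outside the range is exactly what the pass rejects
# with a RuntimeError instead of silently overflowing.
bad = torch.tensor([2**40], dtype=torch.int64)
assert bad.max() > int32_info.max
```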

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 15 deletions
@@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult:

 class QuantizeOperatorArguments(ExportPass):
     """
-    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    This pass makes sure that the arguments to clamp.default are quantized correctly.
     More specifically, this pass:
-    - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
-        the folding pass above to make sure that the retraced output of the full.default op is
-        the right dtype.
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """

@@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.target not in {
                 exir_ops.edge.aten.clamp.default,
-                exir_ops.edge.aten.full.default,
             }:
                 continue

@@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult:

             qargs = QuantArgs.from_operator(user.target, user.args)

-            if n.target == exir_ops.edge.aten.full.default:
-                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                    # replace the node arg with a quantized dito and also set dtype
-                    # to get the right output according to the Edge IR specification:
-                    # exir/dialects/edge/edge.yaml:3596
-                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                    n.update_arg(1, quantized_full_value)
-                    n.update_kwarg("dtype", qargs.dtype)
-                    modified = True
-            elif n.target == exir_ops.edge.aten.clamp.default:
+            if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
                 min_val = n.args[1]
                 max_val = None if len(n.args) <= 2 else n.args[2]