Skip to content

Commit 47c2f2e

Browse files
perfreddan80
authored and committed
Add full operator to fold dq/q handling
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I39d11cff0ef78df08e67f216b8e0bb86af9fac26
1 parent e24d503 commit 47c2f2e

File tree

3 files changed

+45
-17
lines changed

3 files changed

+45
-17
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
3232
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
3333
FoldAndAnnotateQParamsPass,
34+
QuantizeFullArgument,
3435
)
3536
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3637
KeepDimsFalseToSqueezePass,
@@ -84,6 +85,7 @@ def transform_to_backend_pipeline(
8485
self.add_pass(Conv1dUnsqueezePass(exported_program))
8586
self.add_pass(DecomposeSoftmaxesPass())
8687
self.add_pass(DecomposeLinearPass())
88+
self.add_pass(QuantizeFullArgument())
8789
self.add_pass(
8890
FoldAndAnnotateQParamsPass(
8991
[
@@ -92,6 +94,7 @@ def transform_to_backend_pipeline(
9294
exir_ops.edge.aten.add.Tensor,
9395
exir_ops.edge.aten.avg_pool2d.default,
9496
exir_ops.edge.aten.convolution.default,
97+
exir_ops.edge.aten.full.default,
9598
]
9699
)
97100
)

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from executorch.exir.pass_base import ExportPass, PassResult
1616
from torch.fx import GraphModule, Node
1717

18+
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
19+
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
20+
1821

1922
def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
2023
"""
@@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]):
7780
self.targeted_ops = targeted_ops
7881

7982
def call(self, graph_module: GraphModule) -> PassResult:
80-
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
81-
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
8283

8384
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
8485
for n in graph_module.graph.nodes:
@@ -145,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
145146

146147
graph_module.recompile()
147148
return PassResult(graph_module, True)
149+
150+
151+
class QuantizeFullArgument(ExportPass):
    """
    Make sure the fill_value for full.default is quantized.

    This pass needs to be run before FoldAndAnnotateQParamsPass to make sure
    that the retraced output of the full.default op is the right dtype.

    For every ``aten.full.default`` node whose first user is a
    ``quantize_per_tensor`` node, the fill value (positional argument 1) is
    replaced by its quantized counterpart and the ``dtype`` keyword argument
    is set to the dtype taken from that quantize node's arguments.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Quantize the fill value of every eligible full.default node in place.

        Args:
            graph_module: The graph module whose nodes are inspected and
                updated in place.

        Returns:
            A PassResult wrapping the same graph module; its modified flag is
            True only if at least one full.default node was rewritten.
        """
        modified = False
        # Loop over the graph nodes and find every full.default node.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.target != exir_ops.edge.aten.full.default:
                continue

            # Make sure we have a quantized operator. Only the first user is
            # inspected. A full node without any users (e.g. a dead node) has
            # nothing to take quantization parameters from, so it is skipped
            # instead of raising IndexError.
            user = next(iter(n.users), None)
            if user is None or user.target != q_op:
                continue

            qargs = QuantArgs.from_operator(user.target, user.args)
            if "dtype" not in n.kwargs or n.kwargs["dtype"] != qargs.dtype:
                # replace the node arg with a quantized ditto and also set dtype
                # to get the right output according to the Edge IR specification:
                # exir/dialects/edge/edge.yaml:3596
                quantized_full_value = qargs.quantize_value(n.args[1]).item()
                n.update_arg(1, quantized_full_value)
                n.update_kwarg("dtype", qargs.dtype)
                modified = True

        return PassResult(graph_module, modified)

backends/arm/operators/op_full.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import (
18-
get_quant_arg_downstream,
19-
quantize_value,
20-
)
2117
from executorch.backends.arm.tosa_utils import tosa_shape
2218
from torch.fx import Node
2319

@@ -41,19 +37,14 @@ def define_node(
4137
shape = tosa_shape(inputs[0].special, output.dim_order)
4238

4339
value = inputs[1].number
44-
if is_quant_node:
45-
qargs = get_quant_arg_downstream(list(node.users)[0])
46-
qvalue = quantize_value(value, qargs)
47-
dtype = ts.DType.INT8
48-
data = np.full(shape, qvalue, dtype=np.int8)
40+
41+
if output.dtype == ts.DType.INT8:
42+
fill_dtype = np.int8
4943
else:
50-
assert (
51-
output.dtype == ts.DType.FP32
52-
), "'Full' currently only supports FP32 for unquantized models."
53-
dtype = ts.DType.FP32
54-
data = np.full(shape, value, dtype=np.float32)
44+
fill_dtype = np.float32
45+
data = np.full(shape, value, dtype=fill_dtype)
5546

56-
tosa_graph.addConst(shape, dtype, data, node.name + "full-const")
47+
tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const")
5748
tosa_graph.addOperator(
5849
ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name]
5950
)

0 commit comments

Comments
 (0)