Skip to content

Commit 12b5eb6

Browse files
authored
Add a pass to fuse mul.Scalar into dequant
Differential Revision: D74626626 Pull Request resolved: #10853
1 parent 101746e commit 12b5eb6

File tree

2 files changed

+100
-7
lines changed

2 files changed

+100
-7
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -814,11 +814,61 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
814814

815815

816816
@register_cadence_pass(CadencePassAttribute(opt_level=1))
817-
class FuseMulIntoDequantPass(ExportPass):
817+
class FuseMulScalarIntoDequantPass(ExportPass):
818818
"""
819-
Looks for the pattern where atem.mul is multiplying the outputs of dequantize
820-
and aten.full. If found, updates the dequant scale to reflect the multiplication
821-
and removes the full and mul nodes.
819+
Looks for the pattern where aten.mul.Scalar is multiplying the
820+
outputs of dequantize. If found, updates the dequant scale
821+
to reflect the multiplication and removes the mul node.
822+
"""
823+
824+
def attempt_fusion(
825+
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
826+
) -> None:
827+
if node.target not in {
828+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
829+
exir_ops.edge.cadence.dequantize_per_tensor.default,
830+
}:
831+
return
832+
833+
# ensure that the single user of dequant is aten.mul.Scalar
834+
user = list(node.users.keys())[0]
835+
if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
836+
return
837+
838+
# bail out if the other arg to mul is a node (fusion needs a scalar constant)
839+
if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
840+
return
841+
842+
new_deq_args = list(node.args)
843+
assert isinstance(node.args[1], Number)
844+
assert isinstance(user.args[1], Number)
845+
# pyre-ignore[58]: Unsupported operand *
846+
new_deq_args[1] = node.args[1] * user.args[1]
847+
848+
logging.debug(
849+
f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
850+
)
851+
852+
user.replace_all_uses_with(node)
853+
node.args = tuple(new_deq_args)
854+
855+
graph_module.graph.erase_node(user)
856+
857+
graph_module.recompile()
858+
859+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
860+
for node in graph_module.graph.nodes:
861+
self.attempt_fusion(graph_module, node)
862+
result = super().call(graph_module)
863+
return result
864+
865+
866+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867+
class FuseMulTensorIntoDequantPass(ExportPass):
868+
"""
869+
Looks for the pattern where aten.mul is multiplying the outputs of dequantize
870+
and aten.full, or vice versa. If found, updates the dequant scale to reflect
871+
the multiplication and removes the full and mul nodes.
822872
"""
823873

824874
def attempt_fusion(
@@ -1017,7 +1067,8 @@ class CadenceFuseOpsInGraph:
10171067
FuseCascadedTransposeOrPermuteOps,
10181068
FuseCascadedViewOps,
10191069
FuseQuantDequantToRequantizePass,
1020-
FuseMulIntoDequantPass,
1070+
FuseMulTensorIntoDequantPass,
1071+
FuseMulScalarIntoDequantPass,
10211072
FuseFullThenReshapePass,
10221073
FuseTransposeOrPermuteOpPairsPass,
10231074
]

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
)
2020
from executorch.backends.cadence.aot.fuse_ops import (
2121
FuseFullThenReshapePass,
22-
FuseMulIntoDequantPass,
22+
FuseMulScalarIntoDequantPass,
23+
FuseMulTensorIntoDequantPass,
2324
FuseQuantDequantToRequantizePass,
2425
FuseTransposeOrPermuteOpPairsPass,
2526
)
@@ -446,7 +447,7 @@ def forward(self, x):
446447

447448
inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
448449
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
449-
graph_module = FuseMulIntoDequantPass()(graph_module).graph_module
450+
graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
450451

451452
# verify that the mul and full ops were removed
452453
self.check_op_counts(
@@ -467,6 +468,47 @@ def forward(self, x):
467468
deq_scale = node.args[1]
468469
self.assertEqual(deq_scale, 4.5)
469470

471+
def test_fuse_mul_scalar_into_dequant(self):
472+
dequant_scale = 0.006
473+
mul_value = 0.3
474+
475+
builder = GraphBuilder()
476+
x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
477+
quant = builder.call_operator(
478+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
479+
args=(x, 1, 0, -128, 127, torch.int8),
480+
)
481+
dequant = builder.call_operator(
482+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
483+
args=(quant, dequant_scale, 5, -128, 127, torch.int8),
484+
)
485+
mul_scalar = builder.call_operator(
486+
op=exir_ops.edge.aten.mul.Scalar,
487+
args=(dequant, mul_value),
488+
)
489+
builder.output(mul_scalar)
490+
graph_module = builder.get_graph_module()
491+
492+
graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module
493+
494+
# verify that the mul op was removed
495+
self.check_op_counts(
496+
graph_module,
497+
expected_op_counts={
498+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
499+
exir_ops.edge.aten.mul.Scalar: 0,
500+
},
501+
)
502+
503+
# verify that the dequant scale value was updated correctly
504+
for node in graph_module.graph.nodes:
505+
if (
506+
node.target
507+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
508+
):
509+
deq_scale = node.args[1]
510+
self.assertEqual(deq_scale, dequant_scale * mul_value)
511+
470512
def test_fuse_then_transpose_pass(self):
471513
# Create a graph with full -> transpose.
472514
builder = GraphBuilder()

0 commit comments

Comments
 (0)