Skip to content

Commit 7a2c875

Browse files
eigen-kfacebook-github-bot
authored and committed
Implement mul.Tensor to quant fusion.
Reviewed By: hsharma35 Differential Revision: D76302365
1 parent 07be1ff commit 7a2c875

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,73 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
862862
result = super().call(graph_module)
863863
return result
864864

865+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Tensor is followed by a quant node.
    If found, updates the quant scale to reflect the multiplication and
    removes the mul (and, when dead, its full/constant operand) from the graph.
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        """Fuse ``quant(mul(x, full(c)))`` into ``quant(x)`` with scale * c.

        Args:
            graph_module: graph being mutated in place.
            mul_node: candidate node; ignored unless it is aten.mul.Tensor
                with exactly one full.default operand and a single user that
                is a (cadence or decomposed) quantize_per_tensor node.
        """
        if mul_node.target != exir_ops.edge.aten.mul.Tensor:
            return

        full_nodes = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        # Require exactly one constant operand and a single consumer so the
        # multiplication can be folded into that consumer's scale.
        if len(full_nodes) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]
        mul_user = next(iter(mul_node.users))

        if mul_user.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        quant_node = mul_user

        # The tensor input is whichever mul operand is NOT the full node.
        # (Using args[0] unconditionally would wire the constant into the
        # quant node when the pattern is mul(full, x).)
        tensor_input = (
            mul_node.args[1] if mul_node.args[0] is full_node else mul_node.args[0]
        )

        assert isinstance(quant_node.args[1], Number)
        assert isinstance(full_node.args[1], Number)
        # pyre-ignore[58]: Unsupported operand *
        new_scale = quant_node.args[1] * full_node.args[1]

        logging.debug(
            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
        )

        # Route quant's input past the mul, then fold the constant into scale.
        quant_node.replace_input_with(
            cast(torch.fx.Node, quant_node.args[0]),
            cast(torch.fx.Node, tensor_input),
        )
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # The mul is now unused; erase_node detaches its args/users for us.
        graph_module.graph.erase_node(mul_node)
        # The full node may feed other consumers; only erase it if dead.
        if len(full_node.users) == 0:
            graph_module.graph.erase_node(full_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the fusion to every eligible node in the graph."""
        # Snapshot the node list: attempt_fusion erases nodes as it goes.
        for node in list(graph_module.graph.nodes):
            self.attempt_fusion(graph_module, node)
        # Recompile once after all fusions instead of per-fusion.
        graph_module.recompile()
        result = super().call(graph_module)
        return result
865932

866933
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867934
class FuseMulTensorIntoDequantPass(ExportPass):

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
FuseMulTensorIntoDequantPass,
2323
FuseQuantDequantToRequantizePass,
2424
FuseTransposeOrPermuteOpPairsPass,
25+
FuseMulTensorIntoQuantPass,
2526
)
2627
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
2728
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
587588
deq_scale = node.args[1]
588589
self.assertEqual(deq_scale, dequant_scale * mul_value)
589590

591+
def test_fuse_mul_into_quant(self):
    """mul(x, full(c)) followed by quantize should fold into the quant scale."""
    quant_scale = 1.5
    mul_value = 10

    # Build: x -> mul(x, full([1], 10)) -> quantize_per_tensor(scale=1.5)
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
    const = gb.call_operator(
        op=exir_ops.edge.aten.full.default,
        args=([1], mul_value),
    )
    product = gb.call_operator(
        op=exir_ops.edge.aten.mul.Tensor,
        args=(inp, const),
    )
    quantized = gb.call_operator(
        op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        args=(product, quant_scale, 0, 0, 255, torch.uint8),
    )
    gb.output(quantized)

    gm = FuseMulTensorIntoQuantPass()(gb.get_graph_module()).graph_module

    # The mul and full ops must be gone; exactly one quant remains.
    self.check_op_counts(
        gm,
        expected_op_counts={
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            exir_ops.edge.aten.full.default: 0,
            exir_ops.edge.aten.mul.Tensor: 0,
        },
    )

    # The surviving quant node carries the folded scale (1.5 * 10).
    for node in gm.graph.nodes:
        if (
            node.target
            != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        ):
            continue
        deq_scale = node.args[1]
        self.assertEqual(deq_scale, quant_scale * mul_value)
632+
590633
def test_fuse_then_transpose_pass(self):
591634
# Create a graph with full -> transpose.
592635
builder = GraphBuilder()

0 commit comments

Comments
 (0)