[cadence][aot] Implement mul.Tensor to quant fusion. #11580


Merged · 1 commit · Jun 13, 2025
70 changes: 70 additions & 0 deletions backends/cadence/aot/fuse_ops.py
@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
"""
Looks for the pattern where aten.mul.Tensor is followed by quant node.
If found, updates the quant scale to reflect the multiplication and
removes the mul node.
"""

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        full_nodes = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        if len(full_nodes) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]
        mul_user = list(mul_node.users.keys())[0]

        if mul_user.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        quant_node = mul_user

        # Compute the new scale. quantize_per_tensor computes
        # round(input / scale) + zero_point, so folding `x * c` into the
        # quantize node means dividing its scale by c:
        #   round((x * c) / s) == round(x / (s / c))
        prev_scale = quant_node.args[1]
        assert isinstance(prev_scale, (int, float))
        mul_scalar = full_node.args[1]
        assert isinstance(mul_scalar, (int, float))
        if mul_scalar == 0:
            # A zero multiplier cannot be folded into a scale.
            return
        new_scale = float(prev_scale) / float(mul_scalar)

        logging.debug(
            f"Fused {mul_node} and {full_node} into {quant_node}. "
            f"Updated scale from {prev_scale} to {new_scale}"
        )

        # Rewire the quant node to consume the mul's non-constant input;
        # the full node may be either operand of the mul.
        mul_input = (
            mul_node.args[1] if mul_node.args[0] is full_node else mul_node.args[0]
        )
        quant_node.replace_input_with(
            cast(torch.fx.Node, quant_node.args[0]),
            cast(torch.fx.Node, mul_input),
        )

        # Update the scale in the quant node's args.
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # Detach the mul node's inputs and erase it; erase the full node
        # only if nothing else still uses it.
        mul_node.args = ()
        graph_module.graph.erase_node(mul_node)
        if not full_node.users:
            graph_module.graph.erase_node(full_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mul.Tensor
        ):
            self.attempt_fusion(graph_module, node)
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
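
# A minimal numeric check of the scale algebra behind the fusion above
# (illustrative only, not part of the pass): multiplying the quantizer's
# input by c is equivalent to dividing its scale by c, since
# quantize_per_tensor computes round(input / scale) + zero_point.
#
#   import torch
#   x, c, s = torch.randn(4, 32), 10.0, 1.5
#   assert torch.allclose((x * c) / s, x / (s / c))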


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoDequantPass(ExportPass):
"""
43 changes: 43 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -20,6 +20,7 @@
    FuseMMWithAdd,
    FuseMulScalarIntoDequantPass,
    FuseMulTensorIntoDequantPass,
    FuseMulTensorIntoQuantPass,
    FuseQuantDequantToRequantizePass,
    FuseTransposeOrPermuteOpPairsPass,
)
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
                deq_scale = node.args[1]
                self.assertEqual(deq_scale, dequant_scale * mul_value)

    def test_fuse_mul_into_quant(self):
        quant_scale = 1.5
        mul_value = 10

        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
        full = builder.call_operator(
            op=exir_ops.edge.aten.full.default,
            args=([1], mul_value),
        )
        mul = builder.call_operator(
            op=exir_ops.edge.aten.mul.Tensor,
            args=(x, full),
        )
        quant = builder.call_operator(
            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(mul, quant_scale, 0, 0, 255, torch.uint8),
        )
        builder.output(quant)
        graph_module = FuseMulTensorIntoQuantPass()(
            builder.get_graph_module()
        ).graph_module

        # Verify that the mul and full ops were removed.
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                exir_ops.edge.aten.full.default: 0,
                exir_ops.edge.aten.mul.Tensor: 0,
            },
        )

        # Verify that the quant scale was divided by the mul value.
        for node in graph_module.graph.nodes:
            if (
                node.target
                == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
            ):
                fused_scale = node.args[1]
                self.assertEqual(fused_scale, quant_scale / mul_value)

    def test_fuse_then_transpose_pass(self):
        # Create a graph with full -> transpose.
        builder = GraphBuilder()