Use GraphBuilder in unit tests. #10977

Merged 1 commit on May 22, 2025
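
This PR rewrites three fusion-pass tests in test_fusion_ops_passes.py so that the input graphs are built directly with GraphBuilder instead of exporting an eager torch.nn.Module through export_to_edge / quantize_and_export_to_edge. For readers skimming the diff, the sketch below assembles that pattern into one self-contained snippet; the import paths for GraphBuilder and exir_ops are assumptions, since the diff does not show that part of the test file's import block.

```python
# Self-contained sketch of the GraphBuilder-style test used in this PR.
# NOTE: the two "assumed path" imports are not shown in the diff below.
import torch
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder  # assumed path
from executorch.exir.dialects._ops import ops as exir_ops  # assumed path

# Build the exact graph the pass should see: dequantize -> quantize.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
dequant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    args=(x, 1.2, 3, 0, 127, torch.int8),
)
quant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    args=(dequant, 4.5, 6, 0, 127, torch.int8),
)
builder.output(quant)

# Run the pass under test directly on the constructed graph module.
graph_module = FuseQuantDequantToRequantizePass()(
    builder.get_graph_module()
).graph_module
```

Compared with exporting an nn.Module, this keeps each test's input graph explicit and avoids depending on whatever extra ops the quantizer or exporter happens to emit around the operator under test.
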
backends/cadence/aot/tests/test_fusion_ops_passes.py (172 changes: 109 additions & 63 deletions)
@@ -13,10 +13,6 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_edge,
-)
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
     FuseMulScalarIntoDequantPass,
@@ -336,94 +332,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
         )
 
     def test_replace_dequant_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(dequant, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
 
         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
+                # Verify that dequant -> quant was replaced with requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                 exir_ops.edge.cadence.requantize.default: 1,
             },
         )
 
     def test_replace_dequant_permute_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
 
         self.check_op_counts(
             graph_module,
             expected_op_counts={
                 # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                 exir_ops.edge.aten.permute_copy.default: 1,
                 exir_ops.edge.cadence.requantize.default: 1,
             },
         )
 
     def test_remove_nop_dequant_quant(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.lin1 = torch.nn.Linear(6, 12, bias=False)
-                self.lin2 = torch.nn.Linear(12, 24, bias=False)
+        LEADING_DIMS: Final[int] = 12
+        IN_DIM: Final[int] = 6
+        OUT_DIM: Final[int] = 12
 
-            def forward(self, x):
-                x = self.lin1(x)
-                # redundant dequant+quant will be created around this permute
-                x = torch.permute(x, [0, 2, 1, 3])
-                x = self.lin2(x)
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6)
-        model = M()
-        graph_module = (
-            quantize_and_export_to_edge(model, (inputs,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
         )
+        quant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 4.5, 6, 0, 127, torch.int8),
+        )
+        weights = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
+        )
+        bias = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        weight_zero_point = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
+        )
+        out_multiplier = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        out_shift = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
+        )
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        linear1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant1,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear1, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
+        )
+        quant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        linear2 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant2,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear2, 1.2, 3, 0, 127, torch.int8),
+        )
+        builder.output(dequant2)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that one dequant/quant pair was removed
-                # Expect 1 quantize ops: 1 input
+                # Verify that one dequant/quant pair was removed from chain:
+                # quant->linear->dequant->permute->quant->linear->dequant
+                # gets converted to:
+                # quant->linear->permute->linear->dequant
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                # Expect 1 dequant op at the end (output of second linear)
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
             },
         )
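
The check_op_counts helper used by all three tests is part of the existing test harness and is not shown in this diff. Purely for illustration, a hypothetical counter in the same spirit can be written against the public torch.fx API; this is not the helper's actual implementation.

```python
from collections import Counter
from typing import Dict

import torch.fx


def count_call_function_ops(gm: torch.fx.GraphModule) -> Dict[object, int]:
    """Tally how many times each call_function target appears in the graph."""
    counts: Counter = Counter()
    for node in gm.graph.nodes:
        if node.op == "call_function":
            counts[node.target] += 1
    return dict(counts)


# Hypothetical assertion in the spirit of check_op_counts:
#   for op, expected in expected_op_counts.items():
#       assert count_call_function_ops(graph_module).get(op, 0) == expected
```
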