Skip to content

Commit fff7b3c

Browse files
authored
Make test_force_quant_dequant_fusion use GraphBuilder.
Differential Revision: D74841541 Pull Request resolved: #10926
1 parent 633320e commit fff7b3c

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -248,32 +248,28 @@ def forward(self, x):
248248
)
249249

250250
def test_force_quant_dequant_fusion(self):
251-
class M(torch.nn.Module):
252-
def __init__(self):
253-
super().__init__()
254-
255-
def forward(self, x):
256-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
257-
x, 1.2, 3, 0, 127, torch.int8
258-
)
259-
x = torch.permute(x, [2, 0, 1, 3])
260-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
261-
x, 4.5, 6, 0, 127, torch.int8
262-
)
263-
return x
264-
265-
inputs = torch.randn(2, 12, 1, 6)
266-
model = M()
267-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
268-
269-
graph_module = FuseQuantDequantToRequantizePass(
251+
builder = GraphBuilder()
252+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
253+
quant = builder.call_operator(
254+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
255+
args=(x, 1.2, 3, 0, 127, torch.int8),
256+
)
257+
permute = builder.call_operator(
258+
op=exir_ops.edge.aten.permute_copy.default, args=(quant, [2, 0, 1, 3])
259+
)
260+
dequant = builder.call_operator(
261+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
262+
args=(permute, 4.5, 6, 0, 127, torch.int8),
263+
)
264+
builder.output(dequant)
265+
original_graph = builder.get_graph_module()
266+
converted_graph = FuseQuantDequantToRequantizePass(
270267
force_quant_dequant_fusion=True
271-
)(graph_module).graph_module
268+
)(original_graph).graph_module
272269
self.check_op_counts(
273-
graph_module,
270+
converted_graph,
274271
expected_op_counts={
275-
# Verify that no dequant/quant pair was replaced with requantize.
276-
# quantize -> permute -> dequantize should not be replaced with requantize.
272+
# Verify that dequant/quant pair was replaced with requantize.
277273
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
278274
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
279275
exir_ops.edge.cadence.requantize.default: 1,

0 commit comments

Comments
 (0)