Skip to content

Commit fe7eca2

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests. (#10977)
Summary: Pull Request resolved: #10977 Use GraphBuilder to create the model for unit testing. Reviewed By: zonglinpeng Differential Revision: D74907087
1 parent d3c8bcd commit fe7eca2

File tree

1 file changed

+110
-63
lines changed

1 file changed

+110
-63
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
16-
from executorch.backends.cadence.aot.compiler import (
17-
export_to_edge,
18-
quantize_and_export_to_edge,
19-
)
16+
from executorch.backends.cadence.aot.compiler import export_to_edge
2017
from executorch.backends.cadence.aot.fuse_ops import (
2118
FuseFullThenReshapePass,
2219
FuseMulScalarIntoDequantPass,
@@ -343,94 +340,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
343340
)
344341

345342
def test_replace_dequant_quant_with_requantize(self):
346-
class M(torch.nn.Module):
347-
def __init__(self):
348-
super().__init__()
349-
350-
def forward(self, x):
351-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
352-
x, 1.2, 3, 0, 127, torch.int8
353-
)
354-
x = torch.permute(x, [2, 0, 1, 3])
355-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
356-
x, 4.5, 6, 0, 127, torch.int8
357-
)
358-
return x
359-
360-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
361-
model = M()
362-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
363-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
343+
builder = GraphBuilder()
344+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
345+
dequant = builder.call_operator(
346+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
347+
args=(x, 1.2, 3, 0, 127, torch.int8),
348+
)
349+
quant = builder.call_operator(
350+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
351+
args=(dequant, 4.5, 6, 0, 127, torch.int8),
352+
)
353+
builder.output(quant)
354+
graph_module = FuseQuantDequantToRequantizePass()(
355+
builder.get_graph_module()
356+
).graph_module
364357

365358
self.check_op_counts(
366359
graph_module,
367360
expected_op_counts={
368-
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
361+
# Verify that dequant -> quant was replaced with requantize.
369362
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
370363
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
371364
exir_ops.edge.cadence.requantize.default: 1,
372365
},
373366
)
374367

375368
def test_replace_dequant_permute_quant_with_requantize(self):
376-
class M(torch.nn.Module):
377-
def __init__(self):
378-
super().__init__()
379-
380-
def forward(self, x):
381-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
382-
x, 1.2, 3, 0, 127, torch.int8
383-
)
384-
x = torch.permute(x, [2, 0, 1, 3])
385-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
386-
x, 4.5, 6, 0, 127, torch.int8
387-
)
388-
return x
389-
390-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
391-
model = M()
392-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
393-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
369+
builder = GraphBuilder()
370+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
371+
dequant = builder.call_operator(
372+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
373+
args=(x, 1.2, 3, 0, 127, torch.int8),
374+
)
375+
permute = builder.call_operator(
376+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
377+
)
378+
quant = builder.call_operator(
379+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
380+
args=(permute, 4.5, 6, 0, 127, torch.int8),
381+
)
382+
builder.output(quant)
383+
graph_module = FuseQuantDequantToRequantizePass()(
384+
builder.get_graph_module()
385+
).graph_module
394386

395387
self.check_op_counts(
396388
graph_module,
397389
expected_op_counts={
398390
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
399391
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
400392
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
393+
exir_ops.edge.aten.permute_copy.default: 1,
401394
exir_ops.edge.cadence.requantize.default: 1,
402395
},
403396
)
404397

405398
def test_remove_nop_dequant_quant(self):
406-
class M(torch.nn.Module):
407-
def __init__(self):
408-
super(M, self).__init__()
409-
self.lin1 = torch.nn.Linear(6, 12, bias=False)
410-
self.lin2 = torch.nn.Linear(12, 24, bias=False)
411-
412-
def forward(self, x):
413-
x = self.lin1(x)
414-
# redundant dequant+quant will be created around this permute
415-
x = torch.permute(x, [0, 2, 1, 3])
416-
x = self.lin2(x)
417-
return x
399+
LEADING_DIMS: Final[int] = 12
400+
IN_DIM: Final[int] = 6
401+
OUT_DIM: Final[int] = 12
418402

419-
inputs = torch.randn(2, 12, 1, 6)
420-
model = M()
421-
graph_module = (
422-
quantize_and_export_to_edge(model, (inputs,))
423-
.exported_program()
424-
.graph_module
403+
builder = GraphBuilder()
404+
x = builder.placeholder(
405+
"x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
425406
)
426-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
407+
quant1 = builder.call_operator(
408+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
409+
args=(x, 4.5, 6, 0, 127, torch.int8),
410+
)
411+
weights = builder.call_operator(
412+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
413+
)
414+
bias = builder.call_operator(
415+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
416+
)
417+
weight_zero_point = builder.call_operator(
418+
op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
419+
)
420+
out_multiplier = builder.call_operator(
421+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
422+
)
423+
out_shift = builder.call_operator(
424+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
425+
)
426+
linear1 = builder.call_operator(
427+
op=exir_ops.edge.cadence.quantized_linear.default,
428+
args=(
429+
quant1,
430+
weights,
431+
bias,
432+
0, # src_zero_point
433+
weight_zero_point,
434+
out_multiplier,
435+
out_shift,
436+
0, # out_zero_point
437+
None,
438+
),
439+
)
440+
dequant1 = builder.call_operator(
441+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
442+
args=(linear1, 1.2, 3, 0, 127, torch.int8),
443+
)
444+
permute = builder.call_operator(
445+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
446+
)
447+
quant2 = builder.call_operator(
448+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
449+
args=(permute, 4.5, 6, 0, 127, torch.int8),
450+
)
451+
linear2 = builder.call_operator(
452+
op=exir_ops.edge.cadence.quantized_linear.default,
453+
args=(
454+
quant2,
455+
weights,
456+
bias,
457+
0, # src_zero_point
458+
weight_zero_point,
459+
out_multiplier,
460+
out_shift,
461+
0, # out_zero_point
462+
None,
463+
),
464+
)
465+
dequant2 = builder.call_operator(
466+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
467+
args=(linear2, 1.2, 3, 0, 127, torch.int8),
468+
)
469+
builder.output(dequant2)
470+
graph_module = FuseQuantDequantToRequantizePass()(
471+
builder.get_graph_module()
472+
).graph_module
427473
self.check_op_counts(
428474
graph_module,
429475
expected_op_counts={
430-
# Verify that one dequant/quant pair was removed
431-
# Expect 1 quantize ops: 1 input
476+
# Verify that one dequant/quant pair was removed from chain:
477+
# quant->linear->dequant->permute->quant->linear->dequant
478+
# gets converted to:
479+
# quant->linear->permute->linear->dequant
432480
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
433-
# Expect 1 dequant op at the end (output of second linear)
434481
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
435482
},
436483
)

0 commit comments

Comments
 (0)