Skip to content

Commit 5eccadb

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests. (#10977)
Summary: Use GraphBuilder to create the model for unit testing. Reviewed By: zonglinpeng Differential Revision: D74907087
1 parent d50f26f commit 5eccadb

File tree

1 file changed

+110
-63
lines changed

1 file changed

+110
-63
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
16-
from executorch.backends.cadence.aot.compiler import (
17-
export_to_edge,
18-
quantize_and_export_to_edge,
19-
)
16+
from executorch.backends.cadence.aot.compiler import export_to_edge
2017
from executorch.backends.cadence.aot.fuse_ops import (
2118
FuseFullThenReshapePass,
2219
FuseMulScalarIntoDequantPass,
@@ -340,94 +337,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
340337
)
341338

342339
def test_replace_dequant_quant_with_requantize(self):
343-
class M(torch.nn.Module):
344-
def __init__(self):
345-
super().__init__()
346-
347-
def forward(self, x):
348-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
349-
x, 1.2, 3, 0, 127, torch.int8
350-
)
351-
x = torch.permute(x, [2, 0, 1, 3])
352-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
353-
x, 4.5, 6, 0, 127, torch.int8
354-
)
355-
return x
356-
357-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
358-
model = M()
359-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
360-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
340+
builder = GraphBuilder()
341+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
342+
dequant = builder.call_operator(
343+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
344+
args=(x, 1.2, 3, 0, 127, torch.int8),
345+
)
346+
quant = builder.call_operator(
347+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
348+
args=(dequant, 4.5, 6, 0, 127, torch.int8),
349+
)
350+
builder.output(quant)
351+
graph_module = FuseQuantDequantToRequantizePass()(
352+
builder.get_graph_module()
353+
).graph_module
361354

362355
self.check_op_counts(
363356
graph_module,
364357
expected_op_counts={
365-
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
358+
# Verify that dequant -> quant was replaced with requantize.
366359
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
367360
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
368361
exir_ops.edge.cadence.requantize.default: 1,
369362
},
370363
)
371364

372365
def test_replace_dequant_permute_quant_with_requantize(self):
373-
class M(torch.nn.Module):
374-
def __init__(self):
375-
super().__init__()
376-
377-
def forward(self, x):
378-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
379-
x, 1.2, 3, 0, 127, torch.int8
380-
)
381-
x = torch.permute(x, [2, 0, 1, 3])
382-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
383-
x, 4.5, 6, 0, 127, torch.int8
384-
)
385-
return x
386-
387-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
388-
model = M()
389-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
390-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
366+
builder = GraphBuilder()
367+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
368+
dequant = builder.call_operator(
369+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
370+
args=(x, 1.2, 3, 0, 127, torch.int8),
371+
)
372+
permute = builder.call_operator(
373+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
374+
)
375+
quant = builder.call_operator(
376+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
377+
args=(permute, 4.5, 6, 0, 127, torch.int8),
378+
)
379+
builder.output(quant)
380+
graph_module = FuseQuantDequantToRequantizePass()(
381+
builder.get_graph_module()
382+
).graph_module
391383

392384
self.check_op_counts(
393385
graph_module,
394386
expected_op_counts={
395387
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
396388
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
397389
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
390+
exir_ops.edge.aten.permute_copy.default: 1,
398391
exir_ops.edge.cadence.requantize.default: 1,
399392
},
400393
)
401394

402395
def test_remove_nop_dequant_quant(self):
403-
class M(torch.nn.Module):
404-
def __init__(self):
405-
super(M, self).__init__()
406-
self.lin1 = torch.nn.Linear(6, 12, bias=False)
407-
self.lin2 = torch.nn.Linear(12, 24, bias=False)
408-
409-
def forward(self, x):
410-
x = self.lin1(x)
411-
# redundant dequant+quant will be created around this permute
412-
x = torch.permute(x, [0, 2, 1, 3])
413-
x = self.lin2(x)
414-
return x
396+
LEADING_DIMS: Final[int] = 12
397+
IN_DIM: Final[int] = 6
398+
OUT_DIM: Final[int] = 12
415399

416-
inputs = torch.randn(2, 12, 1, 6)
417-
model = M()
418-
graph_module = (
419-
quantize_and_export_to_edge(model, (inputs,))
420-
.exported_program()
421-
.graph_module
400+
builder = GraphBuilder()
401+
x = builder.placeholder(
402+
"x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
422403
)
423-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
404+
quant1 = builder.call_operator(
405+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
406+
args=(x, 4.5, 6, 0, 127, torch.int8),
407+
)
408+
weights = builder.call_operator(
409+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
410+
)
411+
bias = builder.call_operator(
412+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
413+
)
414+
weight_zero_point = builder.call_operator(
415+
op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
416+
)
417+
out_multiplier = builder.call_operator(
418+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
419+
)
420+
out_shift = builder.call_operator(
421+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
422+
)
423+
linear1 = builder.call_operator(
424+
op=exir_ops.edge.cadence.quantized_linear.default,
425+
args=(
426+
quant1,
427+
weights,
428+
bias,
429+
0, # src_zero_point
430+
weight_zero_point,
431+
out_multiplier,
432+
out_shift,
433+
0, # out_zero_point
434+
None,
435+
),
436+
)
437+
dequant1 = builder.call_operator(
438+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
439+
args=(linear1, 1.2, 3, 0, 127, torch.int8),
440+
)
441+
permute = builder.call_operator(
442+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
443+
)
444+
quant2 = builder.call_operator(
445+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
446+
args=(permute, 4.5, 6, 0, 127, torch.int8),
447+
)
448+
linear2 = builder.call_operator(
449+
op=exir_ops.edge.cadence.quantized_linear.default,
450+
args=(
451+
quant2,
452+
weights,
453+
bias,
454+
0, # src_zero_point
455+
weight_zero_point,
456+
out_multiplier,
457+
out_shift,
458+
0, # out_zero_point
459+
None,
460+
),
461+
)
462+
dequant2 = builder.call_operator(
463+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
464+
args=(linear2, 1.2, 3, 0, 127, torch.int8),
465+
)
466+
builder.output(dequant2)
467+
graph_module = FuseQuantDequantToRequantizePass()(
468+
builder.get_graph_module()
469+
).graph_module
424470
self.check_op_counts(
425471
graph_module,
426472
expected_op_counts={
427-
# Verify that one dequant/quant pair was removed
428-
# Expect 1 quantize ops: 1 input
473+
# Verify that one dequant/quant pair was removed from chain:
474+
# quant->linear->dequant->permute->quant->linear->dequant
475+
# gets converted to:
476+
# quant->linear->permute->linear->dequant
429477
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
430-
# Expect 1 dequant op at the end (output of second linear)
431478
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
432479
},
433480
)

0 commit comments

Comments
 (0)