Skip to content

Commit 7229b8b

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests. (#10977)
Summary: Use GraphBuilder to create the model for unit testing. Reviewed By: zonglinpeng Differential Revision: D74907087
1 parent fff7b3c commit 7229b8b

File tree

1 file changed

+110
-63
lines changed

1 file changed

+110
-63
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
16-
from executorch.backends.cadence.aot.compiler import (
17-
export_to_edge,
18-
quantize_and_export_to_edge,
19-
)
16+
from executorch.backends.cadence.aot.compiler import export_to_edge
2017
from executorch.backends.cadence.aot.fuse_ops import (
2118
FuseFullThenReshapePass,
2219
FuseMulScalarIntoDequantPass,
@@ -336,94 +333,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
336333
)
337334

338335
def test_replace_dequant_quant_with_requantize(self):
339-
class M(torch.nn.Module):
340-
def __init__(self):
341-
super().__init__()
342-
343-
def forward(self, x):
344-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
345-
x, 1.2, 3, 0, 127, torch.int8
346-
)
347-
x = torch.permute(x, [2, 0, 1, 3])
348-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
349-
x, 4.5, 6, 0, 127, torch.int8
350-
)
351-
return x
352-
353-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
354-
model = M()
355-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
356-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
336+
builder = GraphBuilder()
337+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
338+
dequant = builder.call_operator(
339+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
340+
args=(x, 1.2, 3, 0, 127, torch.int8),
341+
)
342+
quant = builder.call_operator(
343+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
344+
args=(dequant, 4.5, 6, 0, 127, torch.int8),
345+
)
346+
builder.output(quant)
347+
graph_module = FuseQuantDequantToRequantizePass()(
348+
builder.get_graph_module()
349+
).graph_module
357350

358351
self.check_op_counts(
359352
graph_module,
360353
expected_op_counts={
361-
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
354+
# Verify that dequant -> quant was replaced with requantize.
362355
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
363356
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
364357
exir_ops.edge.cadence.requantize.default: 1,
365358
},
366359
)
367360

368361
def test_replace_dequant_permute_quant_with_requantize(self):
369-
class M(torch.nn.Module):
370-
def __init__(self):
371-
super().__init__()
372-
373-
def forward(self, x):
374-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
375-
x, 1.2, 3, 0, 127, torch.int8
376-
)
377-
x = torch.permute(x, [2, 0, 1, 3])
378-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
379-
x, 4.5, 6, 0, 127, torch.int8
380-
)
381-
return x
382-
383-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
384-
model = M()
385-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
386-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
362+
builder = GraphBuilder()
363+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
364+
dequant = builder.call_operator(
365+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
366+
args=(x, 1.2, 3, 0, 127, torch.int8),
367+
)
368+
permute = builder.call_operator(
369+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
370+
)
371+
quant = builder.call_operator(
372+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
373+
args=(permute, 4.5, 6, 0, 127, torch.int8),
374+
)
375+
builder.output(quant)
376+
graph_module = FuseQuantDequantToRequantizePass()(
377+
builder.get_graph_module()
378+
).graph_module
387379

388380
self.check_op_counts(
389381
graph_module,
390382
expected_op_counts={
391383
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
392384
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
393385
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
386+
exir_ops.edge.aten.permute_copy.default: 1,
394387
exir_ops.edge.cadence.requantize.default: 1,
395388
},
396389
)
397390

398391
def test_remove_nop_dequant_quant(self):
399-
class M(torch.nn.Module):
400-
def __init__(self):
401-
super(M, self).__init__()
402-
self.lin1 = torch.nn.Linear(6, 12, bias=False)
403-
self.lin2 = torch.nn.Linear(12, 24, bias=False)
392+
LEADING_DIMS: Final[int] = 12
393+
IN_DIM: Final[int] = 6
394+
OUT_DIM: Final[int] = 12
404395

405-
def forward(self, x):
406-
x = self.lin1(x)
407-
# redundant dequant+quant will be created around this permute
408-
x = torch.permute(x, [0, 2, 1, 3])
409-
x = self.lin2(x)
410-
return x
411-
412-
inputs = torch.randn(2, 12, 1, 6)
413-
model = M()
414-
graph_module = (
415-
quantize_and_export_to_edge(model, (inputs,))
416-
.exported_program()
417-
.graph_module
396+
builder = GraphBuilder()
397+
x = builder.placeholder(
398+
"x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
399+
)
400+
quant1 = builder.call_operator(
401+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
402+
args=(x, 4.5, 6, 0, 127, torch.int8),
403+
)
404+
weights = builder.call_operator(
405+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
406+
)
407+
bias = builder.call_operator(
408+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
409+
)
410+
weight_zero_point = builder.call_operator(
411+
op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
412+
)
413+
out_multiplier = builder.call_operator(
414+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
415+
)
416+
out_shift = builder.call_operator(
417+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
418418
)
419-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
419+
linear1 = builder.call_operator(
420+
op=exir_ops.edge.cadence.quantized_linear.default,
421+
args=(
422+
quant1,
423+
weights,
424+
bias,
425+
0, # src_zero_point
426+
weight_zero_point,
427+
out_multiplier,
428+
out_shift,
429+
0, # out_zero_point
430+
None,
431+
),
432+
)
433+
dequant1 = builder.call_operator(
434+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
435+
args=(linear1, 1.2, 3, 0, 127, torch.int8),
436+
)
437+
permute = builder.call_operator(
438+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
439+
)
440+
quant2 = builder.call_operator(
441+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
442+
args=(permute, 4.5, 6, 0, 127, torch.int8),
443+
)
444+
linear2 = builder.call_operator(
445+
op=exir_ops.edge.cadence.quantized_linear.default,
446+
args=(
447+
quant2,
448+
weights,
449+
bias,
450+
0, # src_zero_point
451+
weight_zero_point,
452+
out_multiplier,
453+
out_shift,
454+
0, # out_zero_point
455+
None,
456+
),
457+
)
458+
dequant2 = builder.call_operator(
459+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
460+
args=(linear2, 1.2, 3, 0, 127, torch.int8),
461+
)
462+
builder.output(dequant2)
463+
graph_module = FuseQuantDequantToRequantizePass()(
464+
builder.get_graph_module()
465+
).graph_module
420466
self.check_op_counts(
421467
graph_module,
422468
expected_op_counts={
423-
# Verify that one dequant/quant pair was removed
424-
# Expect 1 quantize ops: 1 input
469+
# Verify that one dequant/quant pair was removed from chain:
470+
# quant->linear->dequant->permute->quant->linear->dequant
471+
# gets converted to:
472+
# quant->linear->permute->linear->dequant
425473
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
426-
# Expect 1 dequant op at the end (output of second linear)
427474
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
428475
},
429476
)

0 commit comments

Comments
 (0)