Skip to content

Commit a57ca0e

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests.
Summary: Use GraphBuilder to create the model for unit testing.

Reviewed By: zonglinpeng

Differential Revision: D74907087
1 parent 78227f0 commit a57ca0e

File tree

1 file changed

+110
-63
lines changed

1 file changed

+110
-63
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
16-
from executorch.backends.cadence.aot.compiler import (
17-
export_to_edge,
18-
quantize_and_export_to_edge,
19-
)
16+
from executorch.backends.cadence.aot.compiler import export_to_edge
2017
from executorch.backends.cadence.aot.fuse_ops import (
2118
FuseFullThenReshapePass,
2219
FuseMulScalarIntoDequantPass,
@@ -341,94 +338,144 @@ def forward(self, x):
341338
)
342339

343340
def test_replace_dequant_quant_with_requantize(self):
344-
class M(torch.nn.Module):
345-
def __init__(self):
346-
super().__init__()
347-
348-
def forward(self, x):
349-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
350-
x, 1.2, 3, 0, 127, torch.int8
351-
)
352-
x = torch.permute(x, [2, 0, 1, 3])
353-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
354-
x, 4.5, 6, 0, 127, torch.int8
355-
)
356-
return x
357-
358-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
359-
model = M()
360-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
361-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
341+
builder = GraphBuilder()
342+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
343+
dequant = builder.call_operator(
344+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
345+
args=(x, 1.2, 3, 0, 127, torch.int8),
346+
)
347+
quant = builder.call_operator(
348+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
349+
args=(dequant, 4.5, 6, 0, 127, torch.int8),
350+
)
351+
builder.output(quant)
352+
graph_module = FuseQuantDequantToRequantizePass()(
353+
builder.get_graph_module()
354+
).graph_module
362355

363356
self.check_op_counts(
364357
graph_module,
365358
expected_op_counts={
366-
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
359+
# Verify that dequant -> quant was replaced with requantize.
367360
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
368361
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
369362
exir_ops.edge.cadence.requantize.default: 1,
370363
},
371364
)
372365

373366
def test_replace_dequant_permute_quant_with_requantize(self):
374-
class M(torch.nn.Module):
375-
def __init__(self):
376-
super().__init__()
377-
378-
def forward(self, x):
379-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
380-
x, 1.2, 3, 0, 127, torch.int8
381-
)
382-
x = torch.permute(x, [2, 0, 1, 3])
383-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
384-
x, 4.5, 6, 0, 127, torch.int8
385-
)
386-
return x
387-
388-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
389-
model = M()
390-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
391-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
367+
builder = GraphBuilder()
368+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
369+
dequant = builder.call_operator(
370+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
371+
args=(x, 1.2, 3, 0, 127, torch.int8),
372+
)
373+
permute = builder.call_operator(
374+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
375+
)
376+
quant = builder.call_operator(
377+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
378+
args=(permute, 4.5, 6, 0, 127, torch.int8),
379+
)
380+
builder.output(quant)
381+
graph_module = FuseQuantDequantToRequantizePass()(
382+
builder.get_graph_module()
383+
).graph_module
392384

393385
self.check_op_counts(
394386
graph_module,
395387
expected_op_counts={
396388
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
397389
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
398390
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
391+
exir_ops.edge.aten.permute_copy.default: 1,
399392
exir_ops.edge.cadence.requantize.default: 1,
400393
},
401394
)
402395

403396
def test_remove_nop_dequant_quant(self):
404-
class M(torch.nn.Module):
405-
def __init__(self):
406-
super(M, self).__init__()
407-
self.lin1 = torch.nn.Linear(6, 12, bias=False)
408-
self.lin2 = torch.nn.Linear(12, 24, bias=False)
409-
410-
def forward(self, x):
411-
x = self.lin1(x)
412-
# redundant dequant+quant will be created around this permute
413-
x = torch.permute(x, [0, 2, 1, 3])
414-
x = self.lin2(x)
415-
return x
397+
LEADING_DIMS: Final[int] = 12
398+
IN_DIM: Final[int] = 6
399+
OUT_DIM: Final[int] = 12
416400

417-
inputs = torch.randn(2, 12, 1, 6)
418-
model = M()
419-
graph_module = (
420-
quantize_and_export_to_edge(model, (inputs,))
421-
.exported_program()
422-
.graph_module
401+
builder = GraphBuilder()
402+
x = builder.placeholder(
403+
"x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
423404
)
424-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
405+
quant1 = builder.call_operator(
406+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
407+
args=(x, 4.5, 6, 0, 127, torch.int8),
408+
)
409+
weights = builder.call_operator(
410+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
411+
)
412+
bias = builder.call_operator(
413+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
414+
)
415+
weight_zero_point = builder.call_operator(
416+
op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
417+
)
418+
out_multiplier = builder.call_operator(
419+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
420+
)
421+
out_shift = builder.call_operator(
422+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
423+
)
424+
linear1 = builder.call_operator(
425+
op=exir_ops.edge.cadence.quantized_linear.default,
426+
args=(
427+
quant1,
428+
weights,
429+
bias,
430+
0, # src_zero_point
431+
weight_zero_point,
432+
out_multiplier,
433+
out_shift,
434+
0, # out_zero_point
435+
None,
436+
),
437+
)
438+
dequant1 = builder.call_operator(
439+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
440+
args=(linear1, 1.2, 3, 0, 127, torch.int8),
441+
)
442+
permute = builder.call_operator(
443+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
444+
)
445+
quant2 = builder.call_operator(
446+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
447+
args=(permute, 4.5, 6, 0, 127, torch.int8),
448+
)
449+
linear2 = builder.call_operator(
450+
op=exir_ops.edge.cadence.quantized_linear.default,
451+
args=(
452+
quant2,
453+
weights,
454+
bias,
455+
0, # src_zero_point
456+
weight_zero_point,
457+
out_multiplier,
458+
out_shift,
459+
0, # out_zero_point
460+
None,
461+
),
462+
)
463+
dequant2 = builder.call_operator(
464+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
465+
args=(linear2, 1.2, 3, 0, 127, torch.int8),
466+
)
467+
builder.output(dequant2)
468+
graph_module = FuseQuantDequantToRequantizePass()(
469+
builder.get_graph_module()
470+
).graph_module
425471
self.check_op_counts(
426472
graph_module,
427473
expected_op_counts={
428-
# Verify that one dequant/quant pair was removed
429-
# Expect 1 quantize ops: 1 input
474+
# Verify that one dequant/quant pair was removed from chain:
475+
# quant->linear->dequant->permute->quant->linear->dequant
476+
# gets converted to:
477+
# quant->linear->permute->linear->dequant
430478
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
431-
# Expect 1 dequant op at the end (output of second linear)
432479
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
433480
},
434481
)

0 commit comments

Comments
 (0)