Skip to content

Commit 778ca71

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests. (pytorch#10977)
Summary: Pull Request resolved: pytorch#10977 Use GraphBuilder to create the model for unit testing. Reviewed By: zonglinpeng Differential Revision: D74907087
1 parent cfa1b5e commit 778ca71

File tree

1 file changed

+110
-63
lines changed

1 file changed

+110
-63
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
16-
from executorch.backends.cadence.aot.compiler import (
17-
export_to_edge,
18-
quantize_and_export_to_edge,
19-
)
16+
from executorch.backends.cadence.aot.compiler import export_to_edge
2017
from executorch.backends.cadence.aot.fuse_ops import (
2118
FuseFullThenReshapePass,
2219
FuseMulScalarIntoDequantPass,
@@ -339,94 +336,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
339336
)
340337

341338
def test_replace_dequant_quant_with_requantize(self):
342-
class M(torch.nn.Module):
343-
def __init__(self):
344-
super().__init__()
345-
346-
def forward(self, x):
347-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
348-
x, 1.2, 3, 0, 127, torch.int8
349-
)
350-
x = torch.permute(x, [2, 0, 1, 3])
351-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
352-
x, 4.5, 6, 0, 127, torch.int8
353-
)
354-
return x
355-
356-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
357-
model = M()
358-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
359-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
339+
builder = GraphBuilder()
340+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
341+
dequant = builder.call_operator(
342+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
343+
args=(x, 1.2, 3, 0, 127, torch.int8),
344+
)
345+
quant = builder.call_operator(
346+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
347+
args=(dequant, 4.5, 6, 0, 127, torch.int8),
348+
)
349+
builder.output(quant)
350+
graph_module = FuseQuantDequantToRequantizePass()(
351+
builder.get_graph_module()
352+
).graph_module
360353

361354
self.check_op_counts(
362355
graph_module,
363356
expected_op_counts={
364-
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
357+
# Verify that dequant -> quant was replaced with requantize.
365358
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
366359
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
367360
exir_ops.edge.cadence.requantize.default: 1,
368361
},
369362
)
370363

371364
def test_replace_dequant_permute_quant_with_requantize(self):
372-
class M(torch.nn.Module):
373-
def __init__(self):
374-
super().__init__()
375-
376-
def forward(self, x):
377-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
378-
x, 1.2, 3, 0, 127, torch.int8
379-
)
380-
x = torch.permute(x, [2, 0, 1, 3])
381-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
382-
x, 4.5, 6, 0, 127, torch.int8
383-
)
384-
return x
385-
386-
inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
387-
model = M()
388-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
389-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
365+
builder = GraphBuilder()
366+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
367+
dequant = builder.call_operator(
368+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
369+
args=(x, 1.2, 3, 0, 127, torch.int8),
370+
)
371+
permute = builder.call_operator(
372+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
373+
)
374+
quant = builder.call_operator(
375+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
376+
args=(permute, 4.5, 6, 0, 127, torch.int8),
377+
)
378+
builder.output(quant)
379+
graph_module = FuseQuantDequantToRequantizePass()(
380+
builder.get_graph_module()
381+
).graph_module
390382

391383
self.check_op_counts(
392384
graph_module,
393385
expected_op_counts={
394386
# Verify that dequant -> permute -> quant was replaced with permute -> requantize.
395387
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
396388
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
389+
exir_ops.edge.aten.permute_copy.default: 1,
397390
exir_ops.edge.cadence.requantize.default: 1,
398391
},
399392
)
400393

401394
def test_remove_nop_dequant_quant(self):
402-
class M(torch.nn.Module):
403-
def __init__(self):
404-
super(M, self).__init__()
405-
self.lin1 = torch.nn.Linear(6, 12, bias=False)
406-
self.lin2 = torch.nn.Linear(12, 24, bias=False)
395+
LEADING_DIMS: Final[int] = 12
396+
IN_DIM: Final[int] = 6
397+
OUT_DIM: Final[int] = 12
407398

408-
def forward(self, x):
409-
x = self.lin1(x)
410-
# redundant dequant+quant will be created around this permute
411-
x = torch.permute(x, [0, 2, 1, 3])
412-
x = self.lin2(x)
413-
return x
414-
415-
inputs = torch.randn(2, 12, 1, 6)
416-
model = M()
417-
graph_module = (
418-
quantize_and_export_to_edge(model, (inputs,))
419-
.exported_program()
420-
.graph_module
399+
builder = GraphBuilder()
400+
x = builder.placeholder(
401+
"x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
402+
)
403+
quant1 = builder.call_operator(
404+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
405+
args=(x, 4.5, 6, 0, 127, torch.int8),
406+
)
407+
weights = builder.call_operator(
408+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
409+
)
410+
bias = builder.call_operator(
411+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
412+
)
413+
weight_zero_point = builder.call_operator(
414+
op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
415+
)
416+
out_multiplier = builder.call_operator(
417+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
418+
)
419+
out_shift = builder.call_operator(
420+
op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
421421
)
422-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
422+
linear1 = builder.call_operator(
423+
op=exir_ops.edge.cadence.quantized_linear.default,
424+
args=(
425+
quant1,
426+
weights,
427+
bias,
428+
0, # src_zero_point
429+
weight_zero_point,
430+
out_multiplier,
431+
out_shift,
432+
0, # out_zero_point
433+
None,
434+
),
435+
)
436+
dequant1 = builder.call_operator(
437+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
438+
args=(linear1, 1.2, 3, 0, 127, torch.int8),
439+
)
440+
permute = builder.call_operator(
441+
op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
442+
)
443+
quant2 = builder.call_operator(
444+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
445+
args=(permute, 4.5, 6, 0, 127, torch.int8),
446+
)
447+
linear2 = builder.call_operator(
448+
op=exir_ops.edge.cadence.quantized_linear.default,
449+
args=(
450+
quant2,
451+
weights,
452+
bias,
453+
0, # src_zero_point
454+
weight_zero_point,
455+
out_multiplier,
456+
out_shift,
457+
0, # out_zero_point
458+
None,
459+
),
460+
)
461+
dequant2 = builder.call_operator(
462+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
463+
args=(linear2, 1.2, 3, 0, 127, torch.int8),
464+
)
465+
builder.output(dequant2)
466+
graph_module = FuseQuantDequantToRequantizePass()(
467+
builder.get_graph_module()
468+
).graph_module
423469
self.check_op_counts(
424470
graph_module,
425471
expected_op_counts={
426-
# Verify that one dequant/quant pair was removed
427-
# Expect 1 quantize ops: 1 input
472+
# Verify that one dequant/quant pair was removed from chain:
473+
# quant->linear->dequant->permute->quant->linear->dequant
474+
# gets converted to:
475+
# quant->linear->permute->linear->dequant
428476
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
429-
# Expect 1 dequant op at the end (output of second linear)
430477
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
431478
},
432479
)

0 commit comments

Comments
 (0)