Skip to content

remove exir.capture from quant fusion test #3106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions exir/tests/test_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch
from executorch import exir
from executorch.exir import CaptureConfig, EdgeCompileConfig
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.tests.common import register_additional_test_aten_ops
from torch.ao.quantization import ( # @manual
Expand All @@ -26,6 +26,7 @@
_convert_to_reference_decomposed_fx,
prepare_fx,
)
from torch.export import export
from torch.nn import functional as F

from torch.testing import FileCheck
Expand Down Expand Up @@ -56,9 +57,11 @@ def forward(self, x, y):
)
m = _convert_to_reference_decomposed_fx(m)
config = EdgeCompileConfig(_check_ir_validity=False)
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
# check that we are using functional variant of q/dq/add
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
Expand All @@ -67,12 +70,12 @@ def forward(self, x, y):
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)
m = m.to_executorch()
# check that we are using out variant of q/dq/add
FileCheck().check("torch.ops.quantized_decomposed.add.out").run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

def test_reshape(self) -> None:
Expand All @@ -95,9 +98,11 @@ def forward(self, x, y):
m(*example_inputs)
m = _convert_to_reference_decomposed_fx(m)
config = EdgeCompileConfig(_check_ir_validity=False)
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
# check that we are using functional variant of q/dq/add/reshape
# make sure we only have two quant and one dequant since the q/dq around reshape
# should be fused
Expand All @@ -114,14 +119,14 @@ def forward(self, x, y):
1,
exactly=True,
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False))
# check that we are using out variant of add and view_copy
FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
"torch.ops.aten.view_copy.out"
).run(m.exported_program.graph_module.code)
).run(m.exported_program().graph_module.code)

def test_slice(self) -> None:
"""We don't proactively quantize slice today, but we'll fuse the dq-slice-q
Expand Down Expand Up @@ -150,9 +155,11 @@ def forward(self, x, y):
)
m = _convert_to_reference_decomposed_fx(m)
config = EdgeCompileConfig(_check_ir_validity=False)
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
# check that we are using functional variant of q/dq/add/slice
# make sure we only have one quant and one dequant since the q/dq around slice
# should be fused
Expand All @@ -169,14 +176,14 @@ def forward(self, x, y):
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

m = m.to_executorch()
# check that we are using out variant of add and slice_copy
FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
"torch.ops.aten.slice_copy.Tensor_out"
).run(m.dump_graph_module().code)
).run(m.exported_program().graph_module.code)

def test_cat(self) -> None:
class M(torch.nn.Module):
Expand All @@ -197,9 +204,9 @@ def forward(self, x, y):
m(*example_inputs)
m = _convert_to_reference_decomposed_fx(m)
config = EdgeCompileConfig(_check_ir_validity=False)
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass())
m = m.transform([QuantFusionPass()], check_ir_validity=False)
# check that we are using functional variant of q/dq/cat
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
Expand All @@ -210,7 +217,7 @@ def forward(self, x, y):
1,
exactly=True,
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

m = m.to_executorch()
Expand All @@ -224,7 +231,7 @@ def forward(self, x, y):
).check("torch.ops.aten.cat.out").check_count(
"torch.ops.quantized_decomposed.dequantize_per_tensor.out", 1, exactly=True
).run(
m.dump_graph_module().code
m.exported_program().graph_module.code
)

def test_embedding_byte(self) -> None:
Expand Down Expand Up @@ -292,16 +299,18 @@ def forward(self, indices):
_check_ir_validity=False,
_use_edge_ops=True,
)
m = exir.capture(m, example_inputs).to_edge(config=compile_config)
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
# check that we are using functional variant of quantize_per_channel/embedding_byte
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

# TODO: enable after the out variant of quantize_per_channel is supported
Expand Down Expand Up @@ -348,17 +357,18 @@ def forward(self, indices):
_check_ir_validity=False,
_use_edge_ops=True,
)
m = exir.capture(m, example_inputs).to_edge(config=compile_config)
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
m(*example_inputs)
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
# check that we are using functional variant of quantize_per_channel/embedding_byte
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
).run(
m.exported_program.graph_module.code
m.exported_program().graph_module.code
)

# TODO: enable after the out variant of quantize_per_channel is supported
Expand Down