remove exir.capture from a bunch of tests #2719

Closed · wants to merge 1 commit
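This PR migrates tests off the deprecated `exir.capture(...).to_edge()` flow and onto `to_edge(torch.export.export(...))`. A minimal sketch of the before/after pattern applied throughout the diff below (the `Linear` module and inputs are illustrative placeholders, not taken from any one test):

```python
import torch
from executorch.exir import to_edge
from torch.export import export

model = torch.nn.Linear(3, 3)
inputs = (torch.randn(3, 3),)

# Old flow (removed by this PR):
#   prog = exir.capture(model, inputs, exir.CaptureConfig()).to_edge()
#   graph_module = prog.exported_program.graph_module

# New flow: capture with torch.export, then convert to the Edge dialect.
edge = to_edge(export(model, inputs))
graph_module = edge.exported_program().graph_module
```

The new API also moves backend delegation onto the edge manager, so `to_backend(edge.exported_program, ...)` becomes `edge.to_backend(...)`, as the updated tests show.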
251 changes: 42 additions & 209 deletions exir/backend/test/test_utils.py
@@ -10,7 +10,7 @@

import torch
from executorch import exir
from executorch.exir import CaptureConfig, to_edge
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
@@ -21,31 +21,14 @@
get_non_lowered_nodes,
is_identical_graph,
print_delegated_graph,
remove_first_quant_and_last_dequant,
replace_quantized_partition_with_op,
)

from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from pandas.testing import assert_frame_equal
from torch.ao.quantization import get_default_qconfig # @manual
from torch.ao.quantization.backend_config.executorch import (
get_executorch_backend_config,
)
from torch.ao.quantization.qconfig_mapping import (
_get_symmetric_qnnpack_qconfig_mapping,
QConfigMapping,
)
from torch.ao.quantization.quantize_fx import (
_convert_to_reference_decomposed_fx,
prepare_fx,
)
from torch.export import ExportedProgram
from torch.export import export, ExportedProgram
from torch.fx import symbolic_trace
from torch.fx.passes.utils.fuser_utils import legalize_graph
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.library import Library
from torch.testing import FileCheck

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
@@ -115,22 +98,24 @@ def forward(self, x, y):
return x + 1, y + 2

graph_module_1: torch.fx.GraphModule = (
exir.capture(
MyModule1(),
(torch.rand(3, 4), torch.rand(3, 4)),
CaptureConfig(),
to_edge(
export(
MyModule1(),
(torch.rand(3, 4), torch.rand(3, 4)),
)
)
.to_edge()
.exported_program.graph_module
.exported_program()
.graph_module
)
graph_module_2: torch.fx.GraphModule = (
exir.capture(
MyModule2(),
(torch.rand(3, 4), torch.rand(3, 4)),
CaptureConfig(),
to_edge(
export(
MyModule2(),
(torch.rand(3, 4), torch.rand(3, 4)),
)
)
.to_edge()
.exported_program.graph_module
.exported_program()
.graph_module
)
is_matched = is_identical_graph(graph_module_1, graph_module_2)
self.assertFalse(is_matched)
@@ -149,40 +134,25 @@ def forward(self, x):

inputs = (torch.ones(3, 3),)

# Large model graph:
# opcode name target args kwargs
# ------------- ----------------- ------------------ ------------------------------------------- --------
# placeholder ph_0 ph_0 () {}
# get_attr _param_constant0 _param_constant0 () {}
# call_function add_tensor aten.add.Tensor (ph_0, _param_constant0) {}
# get_attr _param_constant1 _param_constant1 () {}
# get_attr _tensor_constant0 _tensor_constant0 () {}
# call_function addmm_default aten.addmm.default (_param_constant1, ph_0, _tensor_constant0) {}
# output output output ([add_tensor, addmm_default],) {}

large_model = (
exir.capture(
LargeModel(),
inputs,
CaptureConfig(),
to_edge(
export(
LargeModel(),
inputs,
),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.exported_program.graph_module
.exported_program()
.graph_module
)

# Pattern graph:
# opcode name target args kwargs
# ------------- ----------------- ------------------ ------------------------------------------- --------
# placeholder ph_0 ph_0 () {}
# get_attr _param_constant0 _param_constant0 () {}
# get_attr _tensor_constant0 _tensor_constant0 () {}
# call_function addmm_default aten.addmm.default (_param_constant0, ph_0, _tensor_constant0) {}
# output output output ([addmm_default],) {}

pattern = (
exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.exported_program.graph_module.graph
to_edge(
export(torch.nn.Linear(3, 3), inputs),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
.exported_program()
.graph_module.graph
)

subgraph_matcher = SubgraphMatcher(pattern)
@@ -191,65 +161,6 @@ def forward(self, x):
# Should find exactly one match
self.assertEqual(len(match_result), 1)

def test_remove_first_quant_and_last_dequant(self):
qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
linear = torch.nn.Linear(3, 4).eval()

example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
prepared_linear = prepare_fx(
linear,
qconfig_mapping,
example_inputs,
backend_config=get_executorch_backend_config(),
)

converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
prepared_linear,
)

actual_static_quant_linear = (
exir.capture(
converted_linear,
example_inputs,
CaptureConfig(
enable_functionalization=False,
),
)
.to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
)
)
.exported_program.graph_module
)

# Original graph has exactly 3 dequantize ops and 3 quantize ops
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
3,
exactly=True,
).run(actual_static_quant_linear.code)
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
3,
exactly=True,
).run(actual_static_quant_linear.code)

# Remove first and last dequant in static quant
remove_first_quant_and_last_dequant(actual_static_quant_linear)

# Original graph has exactly 2 dequantize ops and 2 quantize ops
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
2,
exactly=True,
).run(actual_static_quant_linear.code)
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
2,
exactly=True,
).run(actual_static_quant_linear.code)

def test_invalid_partitioner_without_partitioner(self):
"""
Tests replacing literals with placeholders in the case there are
@@ -272,13 +183,10 @@ def partition(
tagged_exported_program=edge_exported_program, partition_tags=None
)

exported_program = exir.capture(
torch.nn.Linear(3, 3),
(torch.randn(3, 3),),
CaptureConfig(),
).to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
exported_program = to_edge(
export(
torch.nn.Linear(3, 3),
(torch.randn(3, 3),),
)
)

@@ -287,7 +195,7 @@ def partition(
AssertionError,
error_msg,
):
_ = to_backend(exported_program.exported_program, InvalidPartitioner())
_ = to_backend(exported_program.exported_program(), InvalidPartitioner())

test_lib = Library("test_lib", "DEF")

@@ -298,74 +206,6 @@ def partition(
def q_linear(x, weight, bias):
return x

def test_replace_quantized_partition_with_op(self):
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)

def forward(self, input):
return self.linear(input)

linear_model = LinearModel()
example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
prepared_linear = prepare_fx(
linear_model,
QConfigMapping().set_object_type(
torch.nn.Linear,
get_default_qconfig("qnnpack"),
),
example_inputs,
backend_config=get_executorch_backend_config(),
)

converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
prepared_linear,
)

actual_static_quant_linear = (
exir.capture(
converted_linear,
example_inputs,
CaptureConfig(
enable_functionalization=False,
),
)
.to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
),
)
.exported_program.graph_module
)

source_partitions_by_module = get_source_partitions(
actual_static_quant_linear.graph,
[torch.ao.nn.quantized.reference.modules.linear.Linear],
)

replace_quantized_partition_with_op(
actual_static_quant_linear,
list(source_partitions_by_module.values())[0][0],
torch.ops.test_lib.test_q_linear,
)

legalize_graph(actual_static_quant_linear)

FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
1,
exactly=True,
).run(actual_static_quant_linear.code)
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
1,
exactly=True,
).run(actual_static_quant_linear.code)
FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
actual_static_quant_linear.code
)

def test_get_non_lowered_nodes(self):
class Model(torch.nn.Module):
def __init__(self):
@@ -381,12 +221,9 @@ def forward(self, a, x, b):

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
edge.exported_program = to_backend(
edge.exported_program, AddMulPartitionerDemo()
)
edge.dump()
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
edge = to_edge(export(m, inputs))
edge = edge.to_backend(AddMulPartitionerDemo())
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
# Only sub is not lowerable
self.assertEqual(len(number_of_cpu_nodes), 1)

@@ -405,11 +242,9 @@ def forward(self, a, x, b):

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
edge.exported_program = to_backend(
edge.exported_program, AddMulPartitionerDemo()
)
number_of_delegates = get_delegates(edge.exported_program.graph)
edge = to_edge(export(m, inputs))
edge = edge.to_backend(AddMulPartitionerDemo())
number_of_delegates = get_delegates(edge.exported_program().graph)
# there will be 2 delegates: (mm + add) -> sub -> (mm + add)
self.assertEqual(len(number_of_delegates), 2)

@@ -429,9 +264,7 @@ def forward(self, a, x, b):
m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

edge = to_edge(torch.export.export(m, inputs)).to_backend(
AddMulPartitionerDemo()
)
edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())

graph_str = print_delegated_graph(edge.exported_program().graph_module)
self.assertIn(
22 changes: 2 additions & 20 deletions exir/tests/TARGETS
@@ -3,14 +3,15 @@ load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_library(
name = "lib",
srcs = [
"common.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:schema",
],
)
@@ -162,35 +163,16 @@ python_unittest(
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
":asr_joiner",
"//caffe2:torch",
"//executorch/backends/fb/qnnpack/partition:qnnpack_partitioner",
"//executorch/exir:lib",
"//executorch/exir:memory_planning",
"//executorch/exir:pass_base",
"//executorch/exir:pass_manager",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:sym_shape_eval_pass",
],
)

python_unittest(
name = "experimental",
srcs = [
"test_experimental.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:error",
"//executorch/exir:lib",
"//executorch/exir/experimental:export_pt2",
"//executorch/exir/experimental:lib",
],
)

python_unittest(
name = "passes",
srcs = [