Skip to content

[to_edge] Allow core aten op exception list #5237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import torch

from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass
Expand Down Expand Up @@ -38,6 +39,10 @@ class EdgeCompileConfig:
_check_ir_validity: bool = True
# TODO(larryliu): remove this
_use_edge_ops: bool = True
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
default_factory=list
)
_skip_type_promotion: bool = False
# TODO(gasoonjia): remove this
# TODO(T192537614): reenanle dim order as default
Expand Down
31 changes: 22 additions & 9 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
EXIRATenDialectVerifier()(ep.exported_program.graph_module)
except ExportError:
logging.info(
"If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
"Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
"to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
"If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
"like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
)
Expand All @@ -590,7 +593,11 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
module_call_graph=ep.exported_program.module_call_graph,
example_inputs=ep.exported_program.example_inputs,
constants=ep.exported_program.constants,
verifiers=[get_aten_verifier(enable=config._check_ir_validity)],
verifiers=[
get_aten_verifier(
config=config,
)
],
),
False,
)
Expand Down Expand Up @@ -698,10 +705,13 @@ def _generate_edge_program(
program: ExportedProgram,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:

if config._check_ir_validity:
try:
EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module)
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
exception_list=ops_set_to_not_decompose,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
raise e
Expand Down Expand Up @@ -1020,13 +1030,8 @@ def to_edge_transform_and_lower(
edge_manager = edge_manager.to_backend({name: curr_partitioner})

for name, program in edge_manager._edge_programs.items():
if config._check_ir_validity:
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
)()(program.graph_module)

ops_set_to_not_decompose = set()
ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
partitioners = partitioner.get(name, [])
for curr_partitioner in partitioners:
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
Expand All @@ -1042,6 +1047,13 @@ def to_edge_transform_and_lower(
generate_error=True,
)

if config._check_ir_validity:
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=list(ops_set_to_not_decompose),
)()(program.graph_module)

return edge_manager


Expand Down Expand Up @@ -1107,6 +1119,7 @@ def __init__(
self.compile_config = compile_config or EdgeCompileConfig()
if not isinstance(edge_programs, dict):
edge_programs = {"forward": edge_programs}

for name, program in edge_programs.items():
try:
EXIREdgeDialectVerifier(
Expand Down
37 changes: 36 additions & 1 deletion exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,14 @@ def test_edge_manager_dialect(self):
)
self.assertTrue(edge_manager.exported_program().dialect == "EDGE")

def _test_edge_dialect_verifier(self, callable, validate_ir=True):
def _test_edge_dialect_verifier(
self, callable, validate_ir=True, exception_list=None
):
from executorch.exir import EdgeCompileConfig

edge_compile_config = EdgeCompileConfig(
_check_ir_validity=validate_ir,
_core_aten_ops_exception_list=exception_list,
)
# pre-autograd export. eventually this will become torch.export
one = torch.ones(1, dtype=torch.float)
Expand Down Expand Up @@ -681,3 +684,35 @@ def count_nodes(graph_module, target):
),
1,
)

def test_edge_dialect_non_core_aten_ops(self):
class LinalgNorm(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.linalg.norm(x)

from torch._export.verifier import SpecViolationError

input = torch.arange(9, dtype=torch.float) - 4
ep = torch.export.export(LinalgNorm(), (input,))

# aten::linalg_norm is not a core op, so it should error out
with self.assertRaises(SpecViolationError):
_ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True))

# with exception list, it should not error out
try:
# This should not raise error
_ = to_edge(
ep,
compile_config=EdgeCompileConfig(
_check_ir_validity=True,
_core_aten_ops_exception_list=[
torch.ops.aten.linalg_vector_norm.default
],
),
)
except SpecViolationError:
self.fail("Should not error out on linalg_vector_norm op")
100 changes: 66 additions & 34 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
class EXIRATenDialectVerifierBase(Verifier):
dialect = "OLD_EXIR_ATEN_DISABLED"

def __init__(
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
) -> None:
super().__init__()
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
torch.fx.GraphModule,
Expand All @@ -78,38 +72,68 @@ def __call__(self, *args, **kwargs):
raise RuntimeError("")


class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"
def EXIRATenDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Returns a verifier class that runs ATen dialect specific checks on the graph module.
"""
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = [
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
exception_list += self._exception_list
class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"

return exception_list
def __init__(self) -> None:
super().__init__()
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
self._exception_list = exception_list if exception_list else []

def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace != "aten" or op in self._get_exception_list():
return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
# discussion.
raise SpecViolationError(
f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
)
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = [
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
exception_list += self._exception_list

return exception_list

def get_aten_verifier(enable: bool = True):
return EXIRATenDialectVerifier if enable else EXIRATenDialectVerifierBase
def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace != "aten" or op in self._get_exception_list():
return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
# discussion.
raise SpecViolationError(
f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
)

ret = _EXIRATenDialectVerifier
if not class_only:
ret = ret()
return ret


def get_aten_verifier(config: EdgeCompileConfig):
return (
EXIRATenDialectVerifier(
class_only=True, exception_list=config._core_aten_ops_exception_list
)
if config._check_ir_validity
else EXIRATenDialectVerifierBase
)


def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
Expand Down Expand Up @@ -160,6 +184,12 @@ def EXIREdgeDialectVerifier( # noqa: C901
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)

class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"

Expand All @@ -170,7 +200,9 @@ def __init__(self) -> None:
self.check_edge_ops = _edge_compile_config._use_edge_ops
self.use_dim_order = not _edge_compile_config._skip_dim_order

self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
self.aten_op_verifier = EXIRATenDialectVerifier(
exception_list=exception_list
)
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

if self.check_edge_ops:
Expand Down
Loading