Skip to content

Commit 7942d2c

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Allow core aten op exception list (#5237)
Summary: Currently when a non-core ATen operator shows up in the exported graph, `to_edge()` will fail and the only option is to disable IR validity check by setting `_check_ir_validity=False`. However this is unsafe to do, instead we should still run the rest of the checks. This PR adds support to allow users to bypass core ATen ops check, by passing a list of non-core ATen ops into `to_edge()`. Note that: * This is different than `ops_set_to_not_decompose` in `to_edge_transform_and_lower`, as the ops in `_core_aten_ops_exception_list` are not intended to be kept but more likely showing up because of missing decompositions or missing core ATen tag in `native_functions.yaml`. For this reason, we are combining two lists (`ops_set_to_not_decompose` and `_core_aten_ops_exception_list`) and pass to verifier. * I updated the error log to encourage people to use `_core_aten_ops_exception_list` instead of using `_check_ir_validity=False`. Pull Request resolved: #5237 Test Plan: Added unit test Reviewed By: tarun292 Differential Revision: D62469015 Pulled By: larryliu0820 fbshipit-source-id: 1abb1b4fbbfdf3eb5e64e82e2035c7f93cf5b153
1 parent d423131 commit 7942d2c

File tree

4 files changed

+130
-45
lines changed

4 files changed

+130
-45
lines changed

exir/capture/_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-unsafe
8-
98
from dataclasses import dataclass, field
109
from typing import Dict, List, Optional, Union
1110

11+
import torch
12+
1213
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
1314
from executorch.exir.pass_manager import PassType
1415
from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass
@@ -38,6 +39,10 @@ class EdgeCompileConfig:
3839
_check_ir_validity: bool = True
3940
# TODO(larryliu): remove this
4041
_use_edge_ops: bool = True
42+
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
43+
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
44+
default_factory=list
45+
)
4146
_skip_type_promotion: bool = False
4247
# TODO(gasoonjia): remove this
4348
# TODO(T192537614): reenanle dim order as default

exir/program/_program.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,9 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
573573
EXIRATenDialectVerifier()(ep.exported_program.graph_module)
574574
except ExportError:
575575
logging.info(
576+
"If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
577+
"Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
578+
"to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
576579
"If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
577580
"like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
578581
)
@@ -590,7 +593,11 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
590593
module_call_graph=ep.exported_program.module_call_graph,
591594
example_inputs=ep.exported_program.example_inputs,
592595
constants=ep.exported_program.constants,
593-
verifiers=[get_aten_verifier(enable=config._check_ir_validity)],
596+
verifiers=[
597+
get_aten_verifier(
598+
config=config,
599+
)
600+
],
594601
),
595602
False,
596603
)
@@ -698,10 +705,13 @@ def _generate_edge_program(
698705
program: ExportedProgram,
699706
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
700707
) -> ExportedProgram:
701-
702708
if config._check_ir_validity:
703709
try:
704-
EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module)
710+
EXIRATenDialectVerifier(
711+
edge_compile_config=config,
712+
class_only=False,
713+
exception_list=ops_set_to_not_decompose,
714+
)(program.graph_module)
705715
except ExportError as e:
706716
logging.info(f"Input program {name} is not in ATen dialect.")
707717
raise e
@@ -1020,13 +1030,8 @@ def to_edge_transform_and_lower(
10201030
edge_manager = edge_manager.to_backend({name: curr_partitioner})
10211031

10221032
for name, program in edge_manager._edge_programs.items():
1023-
if config._check_ir_validity:
1024-
EXIREdgeDialectVerifier(
1025-
edge_compile_config=config,
1026-
class_only=True,
1027-
)()(program.graph_module)
10281033

1029-
ops_set_to_not_decompose = set()
1034+
ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
10301035
partitioners = partitioner.get(name, [])
10311036
for curr_partitioner in partitioners:
10321037
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
@@ -1042,6 +1047,13 @@ def to_edge_transform_and_lower(
10421047
generate_error=True,
10431048
)
10441049

1050+
if config._check_ir_validity:
1051+
EXIREdgeDialectVerifier(
1052+
edge_compile_config=config,
1053+
class_only=True,
1054+
exception_list=list(ops_set_to_not_decompose),
1055+
)()(program.graph_module)
1056+
10451057
return edge_manager
10461058

10471059

@@ -1107,6 +1119,7 @@ def __init__(
11071119
self.compile_config = compile_config or EdgeCompileConfig()
11081120
if not isinstance(edge_programs, dict):
11091121
edge_programs = {"forward": edge_programs}
1122+
11101123
for name, program in edge_programs.items():
11111124
try:
11121125
EXIREdgeDialectVerifier(

exir/program/test/test_program.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,14 @@ def test_edge_manager_dialect(self):
531531
)
532532
self.assertTrue(edge_manager.exported_program().dialect == "EDGE")
533533

534-
def _test_edge_dialect_verifier(self, callable, validate_ir=True):
534+
def _test_edge_dialect_verifier(
535+
self, callable, validate_ir=True, exception_list=None
536+
):
535537
from executorch.exir import EdgeCompileConfig
536538

537539
edge_compile_config = EdgeCompileConfig(
538540
_check_ir_validity=validate_ir,
541+
_core_aten_ops_exception_list=exception_list,
539542
)
540543
# pre-autograd export. eventually this will become torch.export
541544
one = torch.ones(1, dtype=torch.float)
@@ -681,3 +684,35 @@ def count_nodes(graph_module, target):
681684
),
682685
1,
683686
)
687+
688+
def test_edge_dialect_non_core_aten_ops(self):
689+
class LinalgNorm(torch.nn.Module):
690+
def __init__(self):
691+
super().__init__()
692+
693+
def forward(self, x: torch.Tensor) -> torch.Tensor:
694+
return torch.linalg.norm(x)
695+
696+
from torch._export.verifier import SpecViolationError
697+
698+
input = torch.arange(9, dtype=torch.float) - 4
699+
ep = torch.export.export(LinalgNorm(), (input,))
700+
701+
# aten::linalg_norm is not a core op, so it should error out
702+
with self.assertRaises(SpecViolationError):
703+
_ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True))
704+
705+
# with exception list, it should not error out
706+
try:
707+
# This should not raise error
708+
_ = to_edge(
709+
ep,
710+
compile_config=EdgeCompileConfig(
711+
_check_ir_validity=True,
712+
_core_aten_ops_exception_list=[
713+
torch.ops.aten.linalg_vector_norm.default
714+
],
715+
),
716+
)
717+
except SpecViolationError:
718+
self.fail("Should not error out on linalg_vector_norm op")

exir/verification/verifier.py

Lines changed: 66 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
5252
class EXIRATenDialectVerifierBase(Verifier):
5353
dialect = "OLD_EXIR_ATEN_DISABLED"
5454

55-
def __init__(
56-
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
57-
) -> None:
58-
super().__init__()
59-
self._exception_list = exception_list if exception_list else []
60-
6155
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
6256
return (
6357
torch.fx.GraphModule,
@@ -78,38 +72,68 @@ def __call__(self, *args, **kwargs):
7872
raise RuntimeError("")
7973

8074

81-
class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
82-
dialect = "OLD_EXIR_ATEN"
75+
def EXIRATenDialectVerifier( # noqa: C901
76+
edge_compile_config: Optional[EdgeCompileConfig] = None,
77+
class_only: bool = False,
78+
exception_list: Optional[List[torch._ops.OpOverload]] = None,
79+
):
80+
"""
81+
Returns a verifier class that runs ATen dialect specific checks on the graph module.
82+
"""
83+
# merge the exception list from edge_compile_config and exception_list
84+
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
85+
exception_list = edge_compile_config._core_aten_ops_exception_list + (
86+
exception_list or []
87+
)
8388

84-
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
85-
exception_list = [
86-
torch.ops.aten.mkldnn_rnn_layer.default,
87-
torch.ops.aten._upsample_bilinear2d_aa.default,
88-
torch.ops.aten.quantize_per_tensor.default,
89-
torch.ops.aten.dequantize.self,
90-
torch.ops.aten.max.default, # TODO(T188268054)
91-
torch.ops.aten.min.default, # TODO(T188268054)
92-
torch.ops.aten.full_like.default, # TODO(T183507359)
93-
]
94-
exception_list += self._exception_list
89+
class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
90+
dialect = "OLD_EXIR_ATEN"
9591

96-
return exception_list
92+
def __init__(self) -> None:
93+
super().__init__()
94+
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
95+
self._exception_list = exception_list if exception_list else []
9796

98-
def check_valid_op(self, op):
99-
if isinstance(op, OpOverload):
100-
# TODO These special ops should be removable easily.
101-
if op.namespace != "aten" or op in self._get_exception_list():
102-
return
103-
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
104-
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
105-
# discussion.
106-
raise SpecViolationError(
107-
f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
108-
)
97+
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
98+
exception_list = [
99+
torch.ops.aten.mkldnn_rnn_layer.default,
100+
torch.ops.aten._upsample_bilinear2d_aa.default,
101+
torch.ops.aten.quantize_per_tensor.default,
102+
torch.ops.aten.dequantize.self,
103+
torch.ops.aten.max.default, # TODO(T188268054)
104+
torch.ops.aten.min.default, # TODO(T188268054)
105+
torch.ops.aten.full_like.default, # TODO(T183507359)
106+
]
107+
exception_list += self._exception_list
109108

109+
return exception_list
110110

111-
def get_aten_verifier(enable: bool = True):
112-
return EXIRATenDialectVerifier if enable else EXIRATenDialectVerifierBase
111+
def check_valid_op(self, op):
112+
if isinstance(op, OpOverload):
113+
# TODO These special ops should be removable easily.
114+
if op.namespace != "aten" or op in self._get_exception_list():
115+
return
116+
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
117+
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
118+
# discussion.
119+
raise SpecViolationError(
120+
f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
121+
)
122+
123+
ret = _EXIRATenDialectVerifier
124+
if not class_only:
125+
ret = ret()
126+
return ret
127+
128+
129+
def get_aten_verifier(config: EdgeCompileConfig):
130+
return (
131+
EXIRATenDialectVerifier(
132+
class_only=True, exception_list=config._core_aten_ops_exception_list
133+
)
134+
if config._check_ir_validity
135+
else EXIRATenDialectVerifierBase
136+
)
113137

114138

115139
def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
@@ -160,6 +184,12 @@ def EXIREdgeDialectVerifier( # noqa: C901
160184
class_only: bool = False,
161185
exception_list: Optional[List[torch._ops.OpOverload]] = None,
162186
):
187+
# merge the exception list from edge_compile_config and exception_list
188+
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
189+
exception_list = edge_compile_config._core_aten_ops_exception_list + (
190+
exception_list or []
191+
)
192+
163193
class _EXIREdgeDialectVerifier(Verifier):
164194
dialect = "EDGE"
165195

@@ -170,7 +200,9 @@ def __init__(self) -> None:
170200
self.check_edge_ops = _edge_compile_config._use_edge_ops
171201
self.use_dim_order = not _edge_compile_config._skip_dim_order
172202

173-
self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
203+
self.aten_op_verifier = EXIRATenDialectVerifier(
204+
exception_list=exception_list
205+
)
174206
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
175207

176208
if self.check_edge_ops:

0 commit comments

Comments
 (0)