Skip to content

Commit 0e7a73c

Browse files
zhxchen17 authored and facebook-github-bot committed
[export] Expand verifier to be multiple on ExportedProgram (pytorch#130364)
Summary: X-link: pytorch/executorch#4184 Pull Request resolved: pytorch#130364 This diff updates the ExportedProgram class in PyTorch to allow for multiple verifiers to be attached to it. This is done by adding a new field to the ExportedProgram schema called "verifiers" which is a list of strings representing the names of the verifiers to be attached to the program. The verifiers are loaded using the "load_verifier" function which is defined in the "torch._export.serde.serialize" module. The "exported_program.dialect" field is also deprecated in favor of the "verifiers" field. Test Plan: CI Differential Revision: D59408546
1 parent a205a53 commit 0e7a73c

File tree

6 files changed

+58
-59
lines changed

6 files changed

+58
-59
lines changed

torch/_export/__init__.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,36 +19,13 @@
1919
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2020
from unittest.mock import patch
2121

22-
import sympy
23-
2422
import torch
25-
import torch._dynamo
2623
import torch.fx
2724
import torch.utils._pytree as pytree
2825

29-
from torch._decomp import core_aten_decompositions, get_decompositions
3026
from torch._dispatch.python import enable_python_dispatcher
31-
from torch._dynamo.exc import UserError, UserErrorType
32-
from torch._dynamo.source import ConstantSource
33-
from torch._export.non_strict_utils import make_constraints
34-
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
35-
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
36-
from torch._functorch.eager_transforms import functionalize
37-
from torch._guards import detect_fake_mode
38-
from torch._inductor import config
39-
from torch._ops import OpOverload
40-
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
41-
from torch._subclasses.functional_tensor import FunctionalTensor
4227
from torch._utils_internal import log_export_usage
4328
from torch.export._tree_utils import reorder_kwargs
44-
from torch.export._unlift import _create_stateful_graph_module
45-
from torch.export.dynamic_shapes import _combine_args, Constraint, dims, dynamic_dim
46-
from torch.export.exported_program import (
47-
_disable_prexisiting_fake_mode,
48-
ExportedProgram,
49-
ModuleCallEntry,
50-
ModuleCallSignature,
51-
)
5229
from torch.export.graph_signature import (
5330
_sig_to_specs,
5431
ArgumentSpec,
@@ -64,14 +41,7 @@
6441
from torch.fx import traceback as fx_traceback
6542
from torch.fx._compatibility import compatibility
6643
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
67-
from torch.fx.experimental.symbolic_shapes import (
68-
ConstraintViolationError,
69-
GuardOnDataDependentSymNode,
70-
ShapeEnv,
71-
StrictMinMaxConstraint,
72-
)
7344
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
74-
from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges
7545

7646
from .wrappers import _wrap_submodules
7747

@@ -89,6 +59,8 @@ class ExportDynamoConfig:
8959
# is called multiple times.
9060
@lru_cache
9161
def capture_pre_autograd_graph_warning():
62+
from torch._inductor import config
63+
9264
log.warning("+============================+")
9365
log.warning("| !!! WARNING !!! |")
9466
log.warning("+============================+")
@@ -138,6 +110,10 @@ def capture_pre_autograd_graph(
138110
"""
139111
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
140112
from torch._utils_internal import export_api_rollout_check
113+
from torch._export.non_strict_utils import make_constraints
114+
from torch._subclasses.functional_tensor import FunctionalTensor
115+
from torch.export._unlift import _create_stateful_graph_module
116+
from torch.export.dynamic_shapes import _combine_args
141117

142118
capture_pre_autograd_graph_warning()
143119

@@ -232,12 +208,13 @@ def _eval(self, mode: bool = True):
232208

233209

234210
def save(
235-
ep: ExportedProgram,
211+
ep,
236212
f: Union[str, os.PathLike, io.BytesIO],
237213
*,
238214
extra_files: Optional[Dict[str, Any]] = None,
239215
opset_version: Optional[Dict[str, int]] = None,
240216
) -> None:
217+
from torch.export.exported_program import ExportedProgram
241218
if not isinstance(ep, ExportedProgram):
242219
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
243220

@@ -270,7 +247,7 @@ def load(
270247
*,
271248
extra_files: Optional[Dict[str, Any]] = None,
272249
expected_opset_version: Optional[Dict[str, int]] = None,
273-
) -> ExportedProgram:
250+
):
274251
if isinstance(f, (str, os.PathLike)):
275252
f = os.fspath(f)
276253

@@ -383,6 +360,7 @@ def aot_compile(
383360
"""
384361
from torch.export._trace import _export_to_torch_ir
385362
from torch._inductor.decomposition import select_decomp_table
363+
from torch._inductor import config
386364

387365
if config.is_predispatch:
388366
gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()

torch/_export/serde/schema.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch._export.serde.union import _Union
99

1010
# NOTE: Please update this value if any modifications are made to the schema
11-
SCHEMA_VERSION = (5, 3)
11+
SCHEMA_VERSION = (6, 1)
1212
TREESPEC_VERSION = 1
1313

1414

@@ -376,4 +376,5 @@ class ExportedProgram:
376376
opset_version: Dict[str, int]
377377
range_constraints: Dict[str, RangeConstraint]
378378
schema_version: SchemaVersion
379-
dialect: str
379+
verifiers: List[str] = field(default_factory=list)
380+
dialect: str = "" # TODO deprecated

torch/_export/serde/schema.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# @generated by update_schema.py
2-
# checksum<<f1da1027d3bccb23db1f8dde8e635e53c7ab67fde5248ede49a6b7a3402ce743>>
2+
# checksum<<41b8154c93fefb96215f252a1913371bf61ad45808995a04ead58236ec3b7720>>
33
Argument:
44
kind: union
55
fields:
@@ -102,8 +102,12 @@ ExportedProgram:
102102
type: Dict[str, RangeConstraint]
103103
schema_version:
104104
type: SchemaVersion
105+
verifiers:
106+
type: List[str]
107+
default: '[]'
105108
dialect:
106109
type: str
110+
default: ''
107111
GradientToParameterSpec:
108112
kind: struct
109113
fields:
@@ -425,6 +429,6 @@ UserOutputSpec:
425429
arg:
426430
type: Argument
427431
SCHEMA_VERSION:
428-
- 5
429-
- 3
432+
- 6
433+
- 1
430434
TREESPEC_VERSION: 1

torch/_export/serde/serialize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,7 +1384,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
13841384
major=SCHEMA_VERSION[0],
13851385
minor=SCHEMA_VERSION[1],
13861386
),
1387-
dialect=exported_program.dialect
1387+
verifiers=[v.dialect for v in exported_program.verifiers],
13881388
)
13891389

13901390
# Test canonical form is well defined.
@@ -2259,8 +2259,8 @@ def deserialize(
22592259
range_constraints=range_constraints,
22602260
module_call_graph=res.module_call_graph,
22612261
example_inputs=res.example_inputs,
2262-
verifier=load_verifier(exported_program.dialect),
22632262
constants=res.constants,
2263+
verifiers=[load_verifier(v) for v in exported_program.verifiers],
22642264
)
22652265

22662266

@@ -2830,7 +2830,7 @@ def replace_output(out):
28302830
opset_version=opset_version,
28312831
range_constraints=range_constraints,
28322832
schema_version=ep.schema_version,
2833-
dialect=ep.dialect
2833+
verifiers=ep.verifiers,
28342834
)
28352835

28362836

torch/_export/verifier.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import math
44
import operator
55
from collections.abc import Iterable
6-
from typing import Any, Dict, final, List, Optional, Tuple, Type
6+
from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING
77

88
import torch
99
from torch._ops import HigherOrderOperator, OpOverload
1010
from torch._subclasses.fake_tensor import FakeTensor
11-
from torch.export.exported_program import ExportedProgram
1211
from torch.export.graph_signature import (
1312
CustomObjArgument,
1413
InputKind,
@@ -17,8 +16,9 @@
1716
TokenArgument,
1817
)
1918
from torch.fx import GraphModule
20-
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
2119

20+
if TYPE_CHECKING:
21+
from torch.export.exported_program import ExportedProgram
2222

2323
class SpecViolationError(Exception):
2424
pass
@@ -34,6 +34,8 @@ def _check_has_fake_tensor(node: torch.fx.Node) -> None:
3434

3535

3636
def _check_val(node: torch.fx.Node) -> None:
37+
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
38+
3739
def _check_correct_val(val):
3840
if val is None:
3941
return True
@@ -150,7 +152,7 @@ def check_additional(self, gm: GraphModule) -> None:
150152
pass
151153

152154
@final
153-
def check(self, ep: ExportedProgram) -> None:
155+
def check(self, ep: "ExportedProgram") -> None:
154156
self._check_graph_module(ep.graph_module)
155157
_verify_exported_program_signature(ep)
156158

@@ -429,7 +431,7 @@ def _verify_exported_program_signature(exported_program) -> None:
429431
)
430432

431433

432-
def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
434+
def load_verifier(dialect: str) -> Type[Verifier]:
433435
if dialect == "ATEN" or dialect == "":
434-
return _VerifierMeta._registry.get(dialect)
436+
return _VerifierMeta._registry.get(dialect, Verifier)
435437
return _VerifierMeta._registry[dialect]

torch/export/exported_program.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Any,
1313
Callable,
1414
Dict,
15+
final,
1516
Iterator,
1617
List,
1718
Optional,
@@ -39,6 +40,8 @@
3940

4041
import torch
4142
import torch.utils._pytree as pytree
43+
44+
from torch._export.verifier import Verifier
4245
from torch._subclasses.functional_tensor import FunctionalTensor
4346

4447
from torch.export._tree_utils import is_equivalent, reorder_kwargs
@@ -66,7 +69,6 @@
6669
TokenArgument,
6770
)
6871

69-
7072
__all__ = [
7173
"ExportedProgram",
7274
"ModuleCallEntry",
@@ -637,13 +639,15 @@ def __init__(
637639
range_constraints: "Dict[sympy.Symbol, Any]",
638640
module_call_graph: List[ModuleCallEntry],
639641
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
640-
verifier: Optional[Type[Any]] = None, # TODO Change typing hint to Verifier.
642+
verifier: Optional[Type[Any]] = None, # TODO Deprecate this.
641643
tensor_constants: Optional[
642644
Dict[str, torch.Tensor]
643645
] = None, # TODO: deprecate this
644646
constants: Optional[
645647
Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
646648
] = None,
649+
*,
650+
verifiers: Optional[List[Type[Verifier]]] = None,
647651
):
648652
# Remove codegen related things from the graph. It should just be a flat graph.
649653
graph._codegen = torch.fx.graph.CodeGen()
@@ -661,14 +665,17 @@ def __init__(
661665
self._constants = tensor_constants or constants or {}
662666
assert self._constants is not None
663667

664-
from torch._export.verifier import Verifier
665-
666-
if verifier is None:
667-
verifier = Verifier
668-
assert issubclass(verifier, Verifier)
669-
self._verifier = verifier
668+
# TODO Clean up this after we bump executorch's pin.
669+
assert verifier is None or verifiers is None
670+
if verifiers is None:
671+
if verifier is None:
672+
verifiers = [Verifier]
673+
else:
674+
verifiers = [verifier]
675+
assert all(issubclass(v, Verifier) for v in verifiers)
676+
self._verifiers = verifiers
670677
# Validate should be always the last step of the constructor.
671-
self.verifier().check(self)
678+
self._validate()
672679

673680
@property
674681
@compatibility(is_backward_compatible=False)
@@ -759,13 +766,18 @@ def call_spec(self):
759766
@property
760767
@compatibility(is_backward_compatible=False)
761768
def verifier(self) -> Any:
762-
return self._verifier
769+
return self._verifiers[0]
763770

764771
@property
765772
@compatibility(is_backward_compatible=False)
766773
def dialect(self) -> str:
767-
assert self._verifier is not None
768-
return self._verifier.dialect
774+
assert self._verifiers is not None
775+
return self._verifiers[0].dialect
776+
777+
@property
778+
@compatibility(is_backward_compatible=False)
779+
def verifiers(self):
780+
return self._verifiers
769781

770782
@property
771783
@compatibility(is_backward_compatible=False)
@@ -1079,8 +1091,10 @@ def _check_input_constraints(self, flat_args_with_path):
10791091
input_placeholders, flat_args_with_path, self.range_constraints
10801092
)
10811093

1094+
@final
10821095
def _validate(self):
1083-
self.verifier().check(self)
1096+
for v in self.verifiers:
1097+
v().check(self)
10841098

10851099
# TODO(zhxchen17) Formalize this.
10861100
def _update(

0 commit comments

Comments (0)