Skip to content

Commit 587f2f8

Browse files
authored
[jit] Remove more reference to TorchScript (#10856)
Summary: Remove some occurrences in comments. Test Plan: None Reviewers: Subscribers: Tasks: Tags: ### Summary [PLEASE REMOVE] See [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests) for ExecuTorch PR guidelines. [PLEASE REMOVE] If this PR closes an issue, please add a `Fixes #<issue-id>` line. [PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: <area>" label. For a list of available release notes labels, check out [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests). ### Test plan [PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.
1 parent f39e694 commit 587f2f8

File tree

5 files changed

+9
-292
lines changed

5 files changed

+9
-292
lines changed

exir/dialects/_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __getattr__(self, op_name):
100100
parent_packet = getattr(self._op_namespace, op_name)
101101
except AttributeError as e:
102102
# Turn this into AttributeError so getattr(obj, key, default)
103-
# works (this is called by TorchScript with __origin__)
103+
# works
104104
raise AttributeError(
105105
f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'"
106106
) from e

exir/serde/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ python_library(
1212
"schema_check.py",
1313
"serialize.py",
1414
"union.py",
15-
"upgrade.py",
1615
],
1716
deps = [
1817
"fbsource//third-party/pypi/sympy:sympy",

exir/serde/export_serialize.py

Lines changed: 4 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@
114114
"ExportedProgramDeserializer",
115115
]
116116

117-
from .upgrade import GraphModuleOpUpgrader
118117

119118
log = logging.getLogger(__name__)
120119

@@ -2220,12 +2219,8 @@ def deserialize_module_call_graph(
22202219

22212220

22222221
class ExportedProgramDeserializer:
2223-
def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
2224-
self.expected_opset_version: Dict[str, int] = {}
2225-
if expected_opset_version:
2226-
self.expected_opset_version.update(expected_opset_version)
2227-
if "aten" not in self.expected_opset_version:
2228-
self.expected_opset_version["aten"] = torch._C._get_max_operator_version()
2222+
def __init__(self):
2223+
pass
22292224

22302225
def deserialize_range_constraints(
22312226
self,
@@ -2278,13 +2273,6 @@ def deserialize(
22782273
symbol_name_to_range,
22792274
res.names_to_symbols,
22802275
)
2281-
model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
2282-
self._validate_model_opset_version(model_opset_version)
2283-
2284-
upgrader = GraphModuleOpUpgrader(
2285-
self.expected_opset_version, model_opset_version
2286-
)
2287-
22882276
exported_program = ep.ExportedProgram(
22892277
root=res.graph_module,
22902278
graph=res.graph_module.graph,
@@ -2296,56 +2284,7 @@ def deserialize(
22962284
verifier=load_verifier(exported_program.dialect),
22972285
constants=res.constants,
22982286
)
2299-
return upgrader.upgrade(exported_program)
2300-
2301-
def _validate_model_opset_version(
2302-
self, model_opset_version: Optional[Dict[str, int]]
2303-
):
2304-
"""Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version
2305-
difference.
2306-
E.g., model_opset_version = {"aten": 3, "custom": 4}
2307-
expected_opset_version = {"aten": 4, "custom": 4}
2308-
This means we can use an upgrader for ATen to reconcile the deserialized model.
2309-
2310-
The logic of this method:
2311-
2312-
For common op namespaces:
2313-
1. if model version < expected version, this case can be handled by upgraders.
2314-
2. if model version > expected version, we need downgraders but not implemented yet.
2315-
3. if model version == expected version, we don't need extra handling.
2316-
2317-
For op namespace only in model_opset_version, we should give a warning because it is missing from
2318-
expected_opset_version.
2319-
"""
2320-
if not model_opset_version:
2321-
raise RuntimeError("Serialized model should have opset version.")
2322-
common_namespaces = {
2323-
key for key in model_opset_version if key in self.expected_opset_version
2324-
}
2325-
for namespace in common_namespaces:
2326-
model_version = model_opset_version[namespace]
2327-
assert isinstance(
2328-
model_version, int
2329-
), f"model_opset_version value should be int, got {model_version}"
2330-
2331-
compiler_version = self.expected_opset_version[namespace]
2332-
assert isinstance(
2333-
compiler_version, int
2334-
), f"expected_opset_version value should be int, got {compiler_version}"
2335-
2336-
# TODO(larryliu0820): Add support for upgrader & downgrader
2337-
if model_version != compiler_version:
2338-
raise NotImplementedError(
2339-
f"Model opset version {model_opset_version} doesn't match to compiler opset version "
2340-
f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet."
2341-
)
2342-
for namespace in model_opset_version:
2343-
if namespace in common_namespaces:
2344-
continue
2345-
log.warning(
2346-
"Compiler doesn't have a version table for op namespace: {ns}. ",
2347-
extra={"ns": namespace},
2348-
)
2287+
return exported_program
23492288

23502289

23512290
class EnumEncoder(json.JSONEncoder):
@@ -2435,15 +2374,14 @@ def _dict_to_dataclass(cls, data):
24352374

24362375
def deserialize(
24372376
artifact: SerializedArtifact,
2438-
expected_opset_version: Optional[Dict[str, int]] = None,
24392377
) -> ep.ExportedProgram:
24402378
assert isinstance(artifact.exported_program, bytes)
24412379
exported_program_str = artifact.exported_program.decode("utf-8")
24422380
exported_program_dict = json.loads(exported_program_str)
24432381
serialized_exported_program = _dict_to_dataclass(
24442382
ExportedProgram, exported_program_dict
24452383
)
2446-
return ExportedProgramDeserializer(expected_opset_version).deserialize(
2384+
return ExportedProgramDeserializer().deserialize(
24472385
serialized_exported_program,
24482386
artifact.state_dict,
24492387
artifact.constants,

exir/serde/serialize.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from executorch.exir.lowered_backend_module import (
3333
LoweredBackendModule as ExirLoweredBackendModule,
3434
)
35-
from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
35+
from executorch.exir.serde.export_serialize import SerializeError
3636
from executorch.exir.serde.schema import (
3737
CompileSpec,
3838
LoweredBackendModule as SerdeLoweredBackendModule,
@@ -617,12 +617,6 @@ def deserialize(
617617
symbol_name_to_range,
618618
res.names_to_symbols,
619619
)
620-
model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
621-
self._validate_model_opset_version(model_opset_version)
622-
623-
upgrader = GraphModuleOpUpgrader(
624-
self.expected_opset_version, model_opset_version
625-
)
626620

627621
dummy_g = torch.fx.Graph()
628622
dummy_g.output(())
@@ -656,7 +650,7 @@ def deserialize(
656650
node.target,
657651
getattr(res.graph_module, node.target),
658652
)
659-
return upgrader.upgrade(exported_program)
653+
return exported_program
660654

661655

662656
def serialize(
@@ -683,15 +677,14 @@ def serialize(
683677

684678
def deserialize(
685679
artifact: export_serialize.SerializedArtifact,
686-
expected_opset_version: Optional[Dict[str, int]] = None,
687680
) -> ep.ExportedProgram:
688681
assert isinstance(artifact.exported_program, bytes)
689682
exported_program_str = artifact.exported_program.decode("utf-8")
690683
exported_program_dict = json.loads(exported_program_str)
691684
serialized_exported_program = export_serialize._dict_to_dataclass(
692685
schema.ExportedProgram, exported_program_dict
693686
)
694-
return ExportedProgramDeserializer(expected_opset_version).deserialize(
687+
return ExportedProgramDeserializer().deserialize(
695688
serialized_exported_program,
696689
artifact.state_dict,
697690
artifact.constants,
@@ -735,7 +728,6 @@ def load(
735728
f: Union[str, os.PathLike[str], io.BytesIO],
736729
*,
737730
extra_files: Optional[Dict[str, Any]] = None,
738-
expected_opset_version: Optional[Dict[str, int]] = None,
739731
) -> ep.ExportedProgram:
740732
if isinstance(f, (str, os.PathLike)):
741733
f = os.fspath(str(f))
@@ -796,6 +788,6 @@ def load(
796788
)
797789

798790
# Deserialize ExportedProgram
799-
ep = deserialize(artifact, expected_opset_version)
791+
ep = deserialize(artifact)
800792

801793
return ep

0 commit comments

Comments
 (0)