Skip to content

Commit 1489e8d

Browse files
committed
Remove upgrade.py since it's not being used
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 1eb5188 commit 1489e8d

File tree

3 files changed

+4
-279
lines changed

3 files changed

+4
-279
lines changed

exir/serde/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ python_library(
1212
"schema_check.py",
1313
"serialize.py",
1414
"union.py",
15-
"upgrade.py",
1615
],
1716
deps = [
1817
"fbsource//third-party/pypi/sympy:sympy",

exir/serde/export_serialize.py

Lines changed: 4 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@
114114
"ExportedProgramDeserializer",
115115
]
116116

117-
from .upgrade import GraphModuleOpUpgrader
118117

119118
log = logging.getLogger(__name__)
120119

@@ -2220,12 +2219,8 @@ def deserialize_module_call_graph(
22202219

22212220

22222221
class ExportedProgramDeserializer:
2223-
def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
2224-
self.expected_opset_version: Dict[str, int] = {}
2225-
if expected_opset_version:
2226-
self.expected_opset_version.update(expected_opset_version)
2227-
if "aten" not in self.expected_opset_version:
2228-
self.expected_opset_version["aten"] = torch._C._get_max_operator_version()
2222+
def __init__(self):
2223+
pass
22292224

22302225
def deserialize_range_constraints(
22312226
self,
@@ -2278,13 +2273,6 @@ def deserialize(
22782273
symbol_name_to_range,
22792274
res.names_to_symbols,
22802275
)
2281-
model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
2282-
self._validate_model_opset_version(model_opset_version)
2283-
2284-
upgrader = GraphModuleOpUpgrader(
2285-
self.expected_opset_version, model_opset_version
2286-
)
2287-
22882276
exported_program = ep.ExportedProgram(
22892277
root=res.graph_module,
22902278
graph=res.graph_module.graph,
@@ -2296,56 +2284,7 @@ def deserialize(
22962284
verifier=load_verifier(exported_program.dialect),
22972285
constants=res.constants,
22982286
)
2299-
return upgrader.upgrade(exported_program)
2300-
2301-
def _validate_model_opset_version(
2302-
self, model_opset_version: Optional[Dict[str, int]]
2303-
):
2304-
"""Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version
2305-
difference.
2306-
E.g., model_opset_version = {"aten": 3, "custom": 4}
2307-
expected_opset_version = {"aten": 4, "custom": 4}
2308-
This means we can use an upgrader for ATen to reconcile the deserialized model.
2309-
2310-
The logic of this method:
2311-
2312-
For common op namespaces:
2313-
1. if model version < expected version, this case can be handled by upgraders.
2314-
2. if model version > expected version, we need downgraders but not implemented yet.
2315-
3. if model version == expected version, we don't need extra handling.
2316-
2317-
For op namespace only in model_opset_version, we should give a warning because it is missing from
2318-
expected_opset_version.
2319-
"""
2320-
if not model_opset_version:
2321-
raise RuntimeError("Serialized model should have opset version.")
2322-
common_namespaces = {
2323-
key for key in model_opset_version if key in self.expected_opset_version
2324-
}
2325-
for namespace in common_namespaces:
2326-
model_version = model_opset_version[namespace]
2327-
assert isinstance(
2328-
model_version, int
2329-
), f"model_opset_version value should be int, got {model_version}"
2330-
2331-
compiler_version = self.expected_opset_version[namespace]
2332-
assert isinstance(
2333-
compiler_version, int
2334-
), f"expected_opset_version value should be int, got {compiler_version}"
2335-
2336-
# TODO(larryliu0820): Add support for upgrader & downgrader
2337-
if model_version != compiler_version:
2338-
raise NotImplementedError(
2339-
f"Model opset version {model_opset_version} doesn't match to compiler opset version "
2340-
f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet."
2341-
)
2342-
for namespace in model_opset_version:
2343-
if namespace in common_namespaces:
2344-
continue
2345-
log.warning(
2346-
"Compiler doesn't have a version table for op namespace: {ns}. ",
2347-
extra={"ns": namespace},
2348-
)
2287+
return exported_program
23492288

23502289

23512290
class EnumEncoder(json.JSONEncoder):
@@ -2435,15 +2374,14 @@ def _dict_to_dataclass(cls, data):
24352374

24362375
def deserialize(
24372376
artifact: SerializedArtifact,
2438-
expected_opset_version: Optional[Dict[str, int]] = None,
24392377
) -> ep.ExportedProgram:
24402378
assert isinstance(artifact.exported_program, bytes)
24412379
exported_program_str = artifact.exported_program.decode("utf-8")
24422380
exported_program_dict = json.loads(exported_program_str)
24432381
serialized_exported_program = _dict_to_dataclass(
24442382
ExportedProgram, exported_program_dict
24452383
)
2446-
return ExportedProgramDeserializer(expected_opset_version).deserialize(
2384+
return ExportedProgramDeserializer().deserialize(
24472385
serialized_exported_program,
24482386
artifact.state_dict,
24492387
artifact.constants,

exir/serde/upgrade.py

Lines changed: 0 additions & 212 deletions
This file was deleted.

0 commit comments

Comments
 (0)