114
114
"ExportedProgramDeserializer" ,
115
115
]
116
116
117
- from .upgrade import GraphModuleOpUpgrader
118
117
119
118
log = logging .getLogger (__name__ )
120
119
@@ -2220,12 +2219,8 @@ def deserialize_module_call_graph(
2220
2219
2221
2220
2222
2221
class ExportedProgramDeserializer :
2223
- def __init__ (self , expected_opset_version : Optional [Dict [str , int ]] = None ):
2224
- self .expected_opset_version : Dict [str , int ] = {}
2225
- if expected_opset_version :
2226
- self .expected_opset_version .update (expected_opset_version )
2227
- if "aten" not in self .expected_opset_version :
2228
- self .expected_opset_version ["aten" ] = torch ._C ._get_max_operator_version ()
2222
+ def __init__ (self ):
2223
+ pass
2229
2224
2230
2225
def deserialize_range_constraints (
2231
2226
self ,
@@ -2278,13 +2273,6 @@ def deserialize(
2278
2273
symbol_name_to_range ,
2279
2274
res .names_to_symbols ,
2280
2275
)
2281
- model_opset_version : Optional [Dict [str , int ]] = exported_program .opset_version
2282
- self ._validate_model_opset_version (model_opset_version )
2283
-
2284
- upgrader = GraphModuleOpUpgrader (
2285
- self .expected_opset_version , model_opset_version
2286
- )
2287
-
2288
2276
exported_program = ep .ExportedProgram (
2289
2277
root = res .graph_module ,
2290
2278
graph = res .graph_module .graph ,
@@ -2296,56 +2284,7 @@ def deserialize(
2296
2284
verifier = load_verifier (exported_program .dialect ),
2297
2285
constants = res .constants ,
2298
2286
)
2299
- return upgrader .upgrade (exported_program )
2300
-
2301
- def _validate_model_opset_version (
2302
- self , model_opset_version : Optional [Dict [str , int ]]
2303
- ):
2304
- """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version
2305
- difference.
2306
- E.g., model_opset_version = {"aten": 3, "custom": 4}
2307
- expected_opset_version = {"aten": 4, "custom": 4}
2308
- This means we can use an upgrader for ATen to reconcile the deserialized model.
2309
-
2310
- The logic of this method:
2311
-
2312
- For common op namespaces:
2313
- 1. if model version < expected version, this case can be handled by upgraders.
2314
- 2. if model version > expected version, we need downgraders but not implemented yet.
2315
- 3. if model version == expected version, we don't need extra handling.
2316
-
2317
- For op namespace only in model_opset_version, we should give a warning because it is missing from
2318
- expected_opset_version.
2319
- """
2320
- if not model_opset_version :
2321
- raise RuntimeError ("Serialized model should have opset version." )
2322
- common_namespaces = {
2323
- key for key in model_opset_version if key in self .expected_opset_version
2324
- }
2325
- for namespace in common_namespaces :
2326
- model_version = model_opset_version [namespace ]
2327
- assert isinstance (
2328
- model_version , int
2329
- ), f"model_opset_version value should be int, got { model_version } "
2330
-
2331
- compiler_version = self .expected_opset_version [namespace ]
2332
- assert isinstance (
2333
- compiler_version , int
2334
- ), f"expected_opset_version value should be int, got { compiler_version } "
2335
-
2336
- # TODO(larryliu0820): Add support for upgrader & downgrader
2337
- if model_version != compiler_version :
2338
- raise NotImplementedError (
2339
- f"Model opset version { model_opset_version } doesn't match to compiler opset version "
2340
- f"{ self .expected_opset_version } ! Upgrader/downgrader is not implemented yet."
2341
- )
2342
- for namespace in model_opset_version :
2343
- if namespace in common_namespaces :
2344
- continue
2345
- log .warning (
2346
- "Compiler doesn't have a version table for op namespace: {ns}. " ,
2347
- extra = {"ns" : namespace },
2348
- )
2287
+ return exported_program
2349
2288
2350
2289
2351
2290
class EnumEncoder (json .JSONEncoder ):
@@ -2435,15 +2374,14 @@ def _dict_to_dataclass(cls, data):
2435
2374
2436
2375
def deserialize (
2437
2376
artifact : SerializedArtifact ,
2438
- expected_opset_version : Optional [Dict [str , int ]] = None ,
2439
2377
) -> ep .ExportedProgram :
2440
2378
assert isinstance (artifact .exported_program , bytes )
2441
2379
exported_program_str = artifact .exported_program .decode ("utf-8" )
2442
2380
exported_program_dict = json .loads (exported_program_str )
2443
2381
serialized_exported_program = _dict_to_dataclass (
2444
2382
ExportedProgram , exported_program_dict
2445
2383
)
2446
- return ExportedProgramDeserializer (expected_opset_version ).deserialize (
2384
+ return ExportedProgramDeserializer ().deserialize (
2447
2385
serialized_exported_program ,
2448
2386
artifact .state_dict ,
2449
2387
artifact .constants ,
0 commit comments