@@ -349,13 +349,14 @@ def serialize(
349
349
additional_kwargs ["verifiers" ] = [
350
350
v .dialect for v in exported_program .verifiers
351
351
]
352
+ elif hasattr (exported_program , "dialect" ):
353
+ additional_kwargs ["dialect" ] = exported_program .dialect
352
354
return export_serialize .SerializedArtifact (
353
355
schema .ExportedProgram (
354
356
graph_module = serialized_graph_module ,
355
357
opset_version = self .opset_version ,
356
358
range_constraints = serialized_range_constraints ,
357
359
schema_version = SchemaVersion (- 1 , - 1 ),
358
- dialect = exported_program .dialect ,
359
360
** additional_kwargs ,
360
361
),
361
362
export_serialize .serialize_torch_artifact (exported_program .state_dict ),
@@ -681,16 +682,22 @@ def deserialize(
681
682
682
683
dummy_g = torch .fx .Graph ()
683
684
dummy_g .output (())
685
+ serialized_ep = serialized_artifact .exported_program
686
+ additional_kwargs = {}
687
+ if hasattr (serialized_ep , "verifiers" ):
688
+ additional_kwargs ["verifiers" ] = [
689
+ load_verifier (v ) for v in serialized_ep .verifiers # pyre-ignore
690
+ ]
691
+ elif hasattr (serialized_ep , "dialect" ):
692
+ additional_kwargs ["verifier" ] = load_verifier (serialized_ep .dialect ) # pyre-ignore
684
693
exported_program = exir .ExportedProgram (
685
694
root = state_dict ,
686
695
graph = dummy_g ,
687
696
graph_signature = ep .ExportGraphSignature (input_specs = [], output_specs = []),
688
697
state_dict = state_dict , # TODO(T157676982)
689
698
range_constraints = range_constraints ,
690
699
module_call_graph = module_call_graph ,
691
- verifier = load_verifier (
692
- serialized_artifact .exported_program .dialect # pyre-ignore
693
- ),
700
+ ** additional_kwargs ,
694
701
)
695
702
exported_program .graph_module .graph = graph_module .graph
696
703
exported_program ._graph_signature = res .signature
0 commit comments