Skip to content

Commit a924500

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Expand verifier to be multiple on ExportedProgram (#4184)
Summary: Pull Request resolved: #4184 X-link: pytorch/pytorch#130364 This diff updates the ExportedProgram class in PyTorch to allow for multiple verifiers to be attached to it. This is done by adding a new field to the ExportedProgram schema called "verifiers" which is a list of strings representing the names of the verifiers to be attached to the program. The verifiers are loaded using the "load_verifier" function which is defined in the "torch._export.serde.serialize" module. The "exported_program.dialect" field is also deprecated in favor of the "verifiers" field. Reviewed By: pianpwk Differential Revision: D59408546 fbshipit-source-id: b826bb70a435a39b1f0e62dbb71db38e85cfbdff
1 parent 1ec6263 commit a924500

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

exir/serde/export_serialize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,12 +1058,18 @@ def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact:
10581058
assert n not in constants
10591059
constants[n] = t
10601060

1061+
additional_kwargs = {}
1062+
if hasattr(exported_program, "verifiers"):
1063+
additional_kwargs["verifiers"] = [
1064+
v.dialect for v in exported_program.verifiers
1065+
]
10611066
serialized_ep = ExportedProgram(
10621067
graph_module=serialized_graph_module,
10631068
opset_version=self.opset_version,
10641069
range_constraints=serialized_range_constraints,
10651070
schema_version=SchemaVersion(-1, -1),
10661071
dialect=exported_program.dialect,
1072+
**additional_kwargs,
10671073
)
10681074

10691075
return SerializedArtifact(

exir/serde/serialize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,13 +344,19 @@ def serialize(
344344
assert n not in constants
345345
constants[n] = t
346346

347+
additional_kwargs = {}
348+
if hasattr(exported_program, "verifiers"):
349+
additional_kwargs["verifiers"] = [
350+
v.dialect for v in exported_program.verifiers
351+
]
347352
return export_serialize.SerializedArtifact(
348353
schema.ExportedProgram(
349354
graph_module=serialized_graph_module,
350355
opset_version=self.opset_version,
351356
range_constraints=serialized_range_constraints,
352357
schema_version=SchemaVersion(-1, -1),
353358
dialect=exported_program.dialect,
359+
**additional_kwargs,
354360
),
355361
export_serialize.serialize_torch_artifact(exported_program.state_dict),
356362
export_serialize.serialize_torch_artifact(constants),

0 commit comments

Comments
 (0)