Skip to content

Commit 66a4633

Browse files
angelayifacebook-github-bot
authored andcommitted
Store the arguments used to trace the exported program in itself (#115)
Summary: Pull Request resolved: #115 Continuation of pytorch/pytorch#107704 Differential Revision: D48637348 fbshipit-source-id: a60b38ae9687fa3d7f84f0b9dfd912209bec1009
1 parent 8683882 commit 66a4633

File tree

5 files changed

+22
-9
lines changed

5 files changed

+22
-9
lines changed

exir/backend/backend_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def to_backend(
302302
copy.deepcopy(edge_program.range_constraints),
303303
copy.deepcopy(edge_program.equality_constraints),
304304
copy.deepcopy(edge_program.module_call_graph),
305+
(),
305306
)
306307

307308

exir/capture/_capture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def convert_to_fake(x):
213213
{},
214214
[],
215215
[],
216+
(),
216217
)
217218
return ExirExportedProgram(ep, False)
218219

exir/lowered_backend_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def create_exported_program_from_submodule(
244244
range_constraints=copy.deepcopy(owning_program.range_constraints),
245245
equality_constraints=[],
246246
module_call_graph=[],
247+
original_traced_arguments=(),
247248
)
248249

249250

exir/serde/serialize.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def serialize_bytes(b: bytes) -> str:
291291
(
292292
serialized_original_module,
293293
serialized_original_state_dict,
294+
_,
294295
) = ExportedProgramSerializer().serialize(lowered_module.original_module)
295296

296297
serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
@@ -312,8 +313,8 @@ def serialize_bytes(b: bytes) -> str:
312313

313314
class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
314315
def serialize(
315-
self, exported_program: torch._export.ExportedProgram
316-
) -> Tuple[schema.ExportedProgram, bytes]:
316+
self, exported_program: ep.ExportedProgram
317+
) -> Tuple[schema.ExportedProgram, bytes, bytes]:
317318
assert isinstance(exported_program, torch._export.ExportedProgram)
318319
gm_serializer = GraphModuleSerializer(
319320
exported_program.graph_signature, exported_program.call_spec
@@ -337,7 +338,8 @@ def serialize(
337338
equality_constraints=serialized_equality_constraints,
338339
schema_version=schema.SCHEMA_VERSION,
339340
),
340-
export_serialize.serialize_state_dict(gm_serializer.state_dict),
341+
export_serialize.serialize_torch_artifact(gm_serializer.state_dict),
342+
b"",
341343
)
342344

343345

@@ -595,6 +597,7 @@ def deserialize_lowered_module(
595597
original_module = ExportedProgramDeserializer().deserialize(
596598
serialized_lowered_module.original_module,
597599
base64.b64decode(serialized_lowered_module.original_state_dict),
600+
b"",
598601
)
599602

600603
lowered_module = ExirLoweredBackendModule(
@@ -612,6 +615,7 @@ def deserialize(
612615
self,
613616
serialized_exported_program: schema.ExportedProgram,
614617
serialized_state_dict: bytes,
618+
serialized_original_traced_args: bytes,
615619
) -> exir.ExportedProgram:
616620
symbol_name_to_range = {
617621
k: symbolic_shapes.ValueRanges(
@@ -620,7 +624,8 @@ def deserialize(
620624
)
621625
for k, v in serialized_exported_program.range_constraints.items()
622626
}
623-
state_dict = export_serialize.deserialize_state_dict(serialized_state_dict)
627+
state_dict = export_serialize.deserialize_torch_artifact(serialized_state_dict)
628+
assert isinstance(state_dict, dict)
624629
(
625630
graph_module,
626631
sig,
@@ -656,27 +661,29 @@ def deserialize(
656661
range_constraints,
657662
equality_constraints,
658663
[],
664+
(),
659665
)
660666

661667

662668
def serialize(
663669
exported_program: torch._export.ExportedProgram,
664670
opset_version: Optional[Dict[str, int]] = None,
665-
) -> Tuple[bytes, bytes]:
666-
serialized_exported_program, serialized_state_dict = ExportedProgramSerializer(
671+
) -> Tuple[bytes, bytes, bytes]:
672+
serialized_exported_program, *_ = ExportedProgramSerializer(
667673
opset_version
668674
).serialize(exported_program)
669675
json_program = json.dumps(
670676
dataclasses.asdict(serialized_exported_program),
671677
cls=export_serialize.EnumEncoder,
672678
)
673679
json_bytes = json_program.encode("utf-8")
674-
return json_bytes, serialized_state_dict
680+
return json_bytes, *_
675681

676682

677683
def deserialize(
678684
exported_program_bytes: bytes,
679685
state_dict: bytes,
686+
original_traced_args: bytes,
680687
expected_opset_version: Optional[Dict[str, int]] = None,
681688
) -> exir.ExportedProgram:
682689
exported_program_str = exported_program_bytes.decode("utf-8")
@@ -689,5 +696,7 @@ def deserialize(
689696
schema.ExportedProgram, exported_program_dict
690697
)
691698
return ExportedProgramDeserializer(expected_opset_version).deserialize(
692-
serialized_exported_program, state_dict
699+
serialized_exported_program,
700+
state_dict,
701+
original_traced_args,
693702
)

sdk/etrecord/_etrecord.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _handle_exported_program(
3838
etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
3939
) -> None:
4040
assert isinstance(ep, ExportedProgram)
41-
serialized_ep, serialized_state_dict = serialize(ep)
41+
serialized_ep, serialized_state_dict, _ = serialize(ep)
4242
etrecord_zip.writestr(f"{module_name}/{method_name}", serialized_ep)
4343
etrecord_zip.writestr(
4444
f"{module_name}/{method_name}_state_dict", serialized_state_dict
@@ -217,6 +217,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:
217217
graph_map[serialized_file] = deserialize(
218218
etrecord_zip.read(serialized_file),
219219
etrecord_zip.read(serialized_state_dict_file),
220+
b"",
220221
)
221222

222223
return ETRecord(

0 commit comments

Comments
 (0)