Skip to content

Commit 9ac99c2

Browse files
committed
[executorch][serialization] Serialize PTD files.
Pull Request resolved: #7270 Introduce top-level serialization file that calls: - serialize_pte_binary for PTE file - FlatTensor.serialize_tensors for PTD files. ghstack-source-id: 257491629 @exported-using-ghexport Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/)
1 parent 444b17b commit 9ac99c2

File tree

4 files changed

+96
-14
lines changed

4 files changed

+96
-14
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.python_library(
3333
"_dataclass.py",
3434
"_flatbuffer.py",
3535
"_program.py",
36+
"_serialize.py",
3637
"utils.py",
3738
"data_serializer.py",
3839
],

exir/_serialize/_serialize.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
10+
from typing import Dict, Tuple
11+
12+
from executorch.exir._serialize import _serialize_pte_binary
13+
14+
from executorch.exir._serialize._cord import Cord
15+
from executorch.exir._serialize.data_serializer import (
16+
DataSerializer,
17+
SerializationInfo,
18+
TensorLayout,
19+
)
20+
21+
from executorch.exir.capture._config import ExecutorchBackendConfig
22+
from executorch.exir.emit import EmitterOutput
23+
from executorch.exir.schema import Tensor, TensorDataLocation
24+
25+
26+
def serialize(
27+
emitter_output: EmitterOutput,
28+
config: ExecutorchBackendConfig,
29+
data_serializer: DataSerializer,
30+
) -> Tuple[Cord, Dict[str, Cord]]:
31+
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
32+
# Serialize PTE file.
33+
pte: Cord = _serialize_pte_binary(
34+
program=emitter_output.program,
35+
mutable_data=emitter_output.mutable_data,
36+
extract_delegate_segments=config.extract_delegate_segments,
37+
segment_alignment=config.segment_alignment,
38+
constant_tensor_alignment=config.constant_tensor_alignment,
39+
delegate_alignment=config.delegate_alignment,
40+
)
41+
42+
# Serialize PTD files.
43+
ptd_files: Dict[str, Cord] = {}
44+
45+
# Find all external tensors and organize into {fqn: Tensor}.
46+
fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
47+
for plan in emitter_output.program.execution_plan:
48+
for evalue in plan.values:
49+
if isinstance(evalue.val, Tensor):
50+
tensor = evalue.val
51+
if (
52+
tensor.extra_tensor_info is not None
53+
and tensor.extra_tensor_info.fully_qualified_name is not None
54+
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
55+
):
56+
fqn_to_tensor_layout[
57+
tensor.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]
58+
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
59+
if len(fqn_to_tensor_layout) > 0:
60+
assert emitter_output.external_constant_map is not None
61+
for (
62+
file,
63+
fqn_map,
64+
) in (
65+
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
66+
emitter_output.external_constant_map.items()
67+
):
68+
ptd_files[file] = data_serializer.serialize_tensors(
69+
SerializationInfo(
70+
emitter_output.external_constant_buffer,
71+
fqn_map,
72+
fqn_to_tensor_layout,
73+
)
74+
)
75+
76+
return pte, ptd_files

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ python_library(
4444
"//executorch/exir/passes:spec_prop_pass",
4545
"//executorch/exir/passes:weights_to_outputs_pass",
4646
"//executorch/exir/verification:verifier",
47+
"//executorch/extension/flat_tensor/serialize:serialize",
4748
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
4849
)
4950

exir/program/_program.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313

1414
import torch
1515
import torch._export
16-
from executorch.exir._serialize import _serialize_pte_binary
1716
from executorch.exir._serialize._cord import Cord
17+
from executorch.exir._serialize._serialize import serialize
18+
from executorch.exir._serialize.data_serializer import DataSerializer
1819
from executorch.exir._warnings import experimental
1920
from executorch.exir.backend.backend_api import to_backend
2021
from executorch.exir.backend.partitioner import Partitioner
@@ -56,6 +57,7 @@
5657
EXIREdgeDialectVerifier,
5758
get_aten_verifier,
5859
)
60+
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
5961
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
6062
from torch.export import ExportedProgram
6163
from torch.export._remove_auto_functionalized_pass import (
@@ -494,23 +496,23 @@ def __init__(
494496
)
495497
self.exported_program = exir_exported_program.exported_program
496498
self._pte_data: Optional[Cord] = None
499+
self._data_files: Optional[Dict[str, Cord]] = None
497500
self._buffer: Optional[bytes] = None
498501
self._emitter_output: Optional[EmitterOutput] = None
499502
self._emit_stacktrace: bool = emit_stacktrace
500503
self._extract_delegate_segments: bool = extract_delegate_segments
501504
self._segment_alignment: int = segment_alignment
502505
self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
503506
self._delegate_alignment: Optional[int] = delegate_alignment
507+
self._data_serializer: DataSerializer = FlatTensorSerializer()
504508

505509
def _get_pte_data(self) -> Cord:
506510
if self._pte_data is None:
507-
self._pte_data = _serialize_pte_binary(
508-
program=self.program,
509-
extract_delegate_segments=self._extract_delegate_segments,
510-
segment_alignment=self._segment_alignment,
511-
constant_tensor_alignment=self._constant_tensor_alignment,
512-
delegate_alignment=self._delegate_alignment,
511+
assert self._emitter_output is not None
512+
self._pte_data, self._data_files = serialize(
513+
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
513514
)
515+
assert self._pte_data is not None
514516
return self._pte_data
515517

516518
@property
@@ -1443,14 +1445,11 @@ def __init__(
14431445
self._config_methods,
14441446
)
14451447

1448+
self._data_serializer = FlatTensorSerializer()
1449+
14461450
# Serialize emitter output, ready to be written to a file.
1447-
self._pte_data: Cord = _serialize_pte_binary(
1448-
program=self._emitter_output.program,
1449-
mutable_data=self._emitter_output.mutable_data,
1450-
extract_delegate_segments=backend_config.extract_delegate_segments,
1451-
segment_alignment=backend_config.segment_alignment,
1452-
constant_tensor_alignment=backend_config.constant_tensor_alignment,
1453-
delegate_alignment=backend_config.delegate_alignment,
1451+
self._pte_data, self._data_files = serialize(
1452+
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
14541453
)
14551454
self._buffer: Optional[bytes] = None
14561455

@@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
15321531
reducing the peak memory usage.
15331532
"""
15341533
self._pte_data.write_to_file(open_file)
1534+
1535+
for filename, cord in self._data_files.items():
1536+
filename = filename + ".ptd"
1537+
with open(filename, "wb") as file:
1538+
cord.write_to_file(file)

0 commit comments

Comments
 (0)