Skip to content

Commit 8d7ab65

Browse files
authored
Refactor FlatTensor
Differential Revision: D76285590 Pull Request resolved: #11499
1 parent 9e9815c commit 8d7ab65

18 files changed

+368
-514
lines changed

exir/_serialize/_serialize.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
DataEntry,
1717
DataPayload,
1818
DataSerializer,
19-
TensorEntry,
2019
TensorLayout,
2120
)
2221

@@ -29,22 +28,22 @@ def serialize_for_executorch(
2928
emitter_output: EmitterOutput,
3029
config: ExecutorchBackendConfig,
3130
data_serializer: DataSerializer,
32-
named_data: Optional[NamedDataStoreOutput] = None,
31+
named_data_store: Optional[NamedDataStoreOutput] = None,
3332
) -> Tuple[Cord, Dict[str, Cord]]:
3433
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
3534

3635
# Serialize PTE file.
3736
pte_named_data = None
3837
if (
39-
named_data is not None
40-
and len(named_data.buffers) > 0
41-
and len(named_data.pte_data) > 0
38+
named_data_store is not None
39+
and len(named_data_store.buffers) > 0
40+
and len(named_data_store.pte_data) > 0
4241
):
4342
# Create a separate NamedDataStoreOutput with only pte_data; exclude
4443
# external_data, which shouldn't be serialized with the PTE file.
4544
pte_named_data = NamedDataStoreOutput(
46-
buffers=named_data.buffers,
47-
pte_data=named_data.pte_data,
45+
buffers=named_data_store.buffers,
46+
pte_data=named_data_store.pte_data,
4847
external_data={},
4948
)
5049
pte: Cord = _serialize_pte_binary(
@@ -72,22 +71,23 @@ def serialize_for_executorch(
7271
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
7372
):
7473
fqn_to_tensor_layout[
74+
# pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
7575
tensor.extra_tensor_info.fully_qualified_name
7676
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
7777

7878
if len(fqn_to_tensor_layout) == 0 and (
79-
named_data is None or len(named_data.external_data) == 0
79+
named_data_store is None or len(named_data_store.external_data) == 0
8080
):
8181
return pte, ptd_files
8282

8383
# Consolidate tensors and opaque data with the same external tag so they
8484
# can be saved to the same PTD.
8585
all_external_tags: Set[str] = set()
86-
if named_data is not None and len(named_data.external_data) > 0:
86+
if named_data_store is not None and len(named_data_store.external_data) > 0:
8787
assert (
88-
len(named_data.buffers) > 0
88+
len(named_data_store.buffers) > 0
8989
), "External data exists, but there are no buffers provided."
90-
all_external_tags = set(named_data.external_data.keys())
90+
all_external_tags = set(named_data_store.external_data.keys())
9191

9292
if len(fqn_to_tensor_layout) > 0:
9393
# emitter_output.external_constant_map contains the mapping from
@@ -103,35 +103,38 @@ def serialize_for_executorch(
103103

104104
for tag in all_external_tags:
105105
buffers = []
106-
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
106+
key_to_data_entry: Dict[str, DataEntry] = {}
107107
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
108108
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
109-
# Create a TensorEntry for each external tensor.
109+
# Create a DataEntry for each external tensor.
110110
for fqn, index in fqn_to_index.items():
111111
assert fqn in fqn_to_tensor_layout
112-
fqn_to_tensor_entry[fqn] = TensorEntry(
112+
assert fqn not in key_to_data_entry # fqn must be unique
113+
key_to_data_entry[fqn] = DataEntry(
113114
buffer_index=len(buffers),
114-
layout=fqn_to_tensor_layout[fqn],
115+
alignment=config.constant_tensor_alignment,
116+
tensor_layout=fqn_to_tensor_layout[fqn],
115117
)
116118
buffers.append(emitter_output.external_constant_buffer[index])
117119

118120
# Extract external data.
119-
key_to_data: Dict[str, DataEntry] = {}
120121
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `external_data`.
121-
key_to_buffer_index = named_data.external_data.get(tag, {})
122+
key_to_buffer_index = named_data_store.external_data.get(tag, {})
122123
for key, index in key_to_buffer_index.items():
123-
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
124-
key_to_data[key] = DataEntry(
125-
len(buffers), named_data.buffers[index].alignment
124+
assert key not in key_to_data_entry # key must be unique
125+
key_to_data_entry[key] = DataEntry(
126+
buffer_index=len(buffers),
127+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
128+
alignment=named_data_store.buffers[index].alignment,
129+
tensor_layout=None,
126130
)
127-
buffers.append(named_data.buffers[index].buffer)
131+
buffers.append(named_data_store.buffers[index].buffer)
128132

129133
# Serialize into PTD file.
130134
ptd_files[tag] = data_serializer.serialize(
131135
DataPayload(
132136
buffers=buffers,
133-
fqn_to_tensor=fqn_to_tensor_entry,
134-
key_to_data=key_to_data,
137+
named_data=key_to_data_entry,
135138
)
136139
)
137140

exir/_serialize/data_serializer.py

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,9 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Dict, List, Sequence
3+
from typing import Dict, Optional, Sequence
44

55
from executorch.exir._serialize._cord import Cord
6-
7-
from executorch.exir.schema import ScalarType
8-
9-
10-
@dataclass
11-
class TensorLayout:
12-
"""Tensor layout information for externally-serialized tensors.
13-
14-
Attributes:
15-
scalar_type: type of the elements in the tensor.
16-
sizes: size of each dim in the tensor.
17-
dim_order: specifies the order the dimensions are laid out in memory,
18-
from outer to inner.
19-
"""
20-
21-
scalar_type: ScalarType
22-
sizes: List[int]
23-
dim_order: List[int]
24-
25-
26-
@dataclass
27-
class TensorEntry:
28-
"""Represents a single tensor in `DataPayload`, specifying its location
29-
and metadata.
30-
31-
Attributes:
32-
buffer_index: The index inside `DataPayload.buffers` that this
33-
TensorEntry refers to.
34-
layout: Metadata about the tensor.
35-
"""
36-
37-
buffer_index: int
38-
layout: TensorLayout
6+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout
397

408

419
@dataclass
@@ -47,33 +15,33 @@ class DataEntry:
4715
buffer_index: The index inside `DataPayload.buffers` that this
4816
DataEntry refers to.
4917
alignment: The alignment of the data.
18+
tensor_layout: If this is a tensor, the tensor layout information.
5019
"""
5120

5221
buffer_index: int
5322
alignment: int
23+
tensor_layout: Optional[TensorLayout]
5424

5525

5626
@dataclass
5727
class DataPayload:
5828
"""Contains the data and metadata required for serialization.
5929
6030
Having an index-based arrangement instead of embedding the buffers in
61-
TensorEntry allows the caller to deduplicate buffers and point multiple
62-
fully qualified names (FQNs) to the same entry.
31+
DataEntry allows the caller to deduplicate buffers and point multiple
32+
keys to the same entry.
6333
6434
Attributes:
65-
buffers: a sequence of tensor buffers.
66-
fqn_to_tensor: a map from fully qualified names to serializable tensors.
67-
key_to_data: a map from unique keys to serializable opaque data.
35+
buffers: a sequence of byte buffers.
36+
named_data: a map from unique keys to serializable data.
6837
"""
6938

7039
buffers: Sequence[bytes]
71-
fqn_to_tensor: Dict[str, TensorEntry]
72-
key_to_data: Dict[str, DataEntry]
40+
named_data: Dict[str, DataEntry]
7341

7442

7543
class DataSerializer(ABC):
76-
"""Serializes and deserializes FQN-tagged tensor data.
44+
"""Serializes and deserializes data. Data can be referenced by a unique key.
7745
7846
This base class enables serialization into different formats. See
7947
executorch/extension/flat_tensor/ for an example.
@@ -85,11 +53,11 @@ def serialize(
8553
data: DataPayload,
8654
) -> Cord:
8755
"""
88-
Serializes a list of tensors emitted by ExecuTorch into a binary blob.
56+
Serializes a list of bytes emitted by ExecuTorch into a binary blob.
8957
9058
Args:
91-
data: the tensor buffers and tensor layout information required for
92-
serialization.
59+
data: buffers and corresponding metadata used for serialization.
60+
9361
9462
Returns:
9563
A binary blob that contains the serialized data.
@@ -99,14 +67,14 @@ def serialize(
9967
@abstractmethod
10068
def deserialize(self, blob: Cord) -> DataPayload:
10169
"""
102-
Deserializes a blob into a list of tensors. Reverses the effect of
70+
Deserializes a blob into a DataPayload. Reverses the effect of
10371
serialize.
10472
10573
Args:
10674
blob: A binary blob that contains the serialized data.
10775
10876
Returns:
109-
DataPayload: tensor buffers and tensor layout information
110-
deserialized from `blob`.
77+
DataPayload: buffers and corresponding metadata deserialized
78+
from `blob`.
11179
"""
11280
raise NotImplementedError("deserialize_data")

0 commit comments

Comments
 (0)