Refactor FlatTensor #11499

Merged
merged 1 commit into from Jun 10, 2025
49 changes: 26 additions & 23 deletions exir/_serialize/_serialize.py
@@ -16,7 +16,6 @@
DataEntry,
DataPayload,
DataSerializer,
TensorEntry,
TensorLayout,
)

@@ -29,22 +28,22 @@ def serialize_for_executorch(
emitter_output: EmitterOutput,
config: ExecutorchBackendConfig,
data_serializer: DataSerializer,
named_data: Optional[NamedDataStoreOutput] = None,
named_data_store: Optional[NamedDataStoreOutput] = None,
) -> Tuple[Cord, Dict[str, Cord]]:
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""

# Serialize PTE file.
pte_named_data = None
if (
named_data is not None
and len(named_data.buffers) > 0
and len(named_data.pte_data) > 0
named_data_store is not None
and len(named_data_store.buffers) > 0
and len(named_data_store.pte_data) > 0
):
# Create a separate NamedDataStoreOutput with only pte_data; exclude
# external_data, which shouldn't be serialized with the PTE file.
pte_named_data = NamedDataStoreOutput(
buffers=named_data.buffers,
pte_data=named_data.pte_data,
buffers=named_data_store.buffers,
pte_data=named_data_store.pte_data,
external_data={},
)
pte: Cord = _serialize_pte_binary(
@@ -72,22 +71,23 @@
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
):
fqn_to_tensor_layout[
# pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`
tensor.extra_tensor_info.fully_qualified_name
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)

if len(fqn_to_tensor_layout) == 0 and (
named_data is None or len(named_data.external_data) == 0
named_data_store is None or len(named_data_store.external_data) == 0
):
return pte, ptd_files

# Consolidate tensors and opaque data with the same external tag so they
# can be saved to the same PTD.
all_external_tags: Set[str] = set()
if named_data is not None and len(named_data.external_data) > 0:
if named_data_store is not None and len(named_data_store.external_data) > 0:
assert (
len(named_data.buffers) > 0
len(named_data_store.buffers) > 0
), "External data exists, but there are no buffers provided."
all_external_tags = set(named_data.external_data.keys())
all_external_tags = set(named_data_store.external_data.keys())

if len(fqn_to_tensor_layout) > 0:
# emitter_output.external_constant_map contains the mapping from
@@ -103,35 +103,38 @@

for tag in all_external_tags:
buffers = []
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
key_to_data_entry: Dict[str, DataEntry] = {}
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
# Create a TensorEntry for each external tensor.
# Create a DataEntry for each external tensor.
for fqn, index in fqn_to_index.items():
assert fqn in fqn_to_tensor_layout
fqn_to_tensor_entry[fqn] = TensorEntry(
assert fqn not in key_to_data_entry # fqn must be unique
key_to_data_entry[fqn] = DataEntry(
buffer_index=len(buffers),
layout=fqn_to_tensor_layout[fqn],
alignment=config.constant_tensor_alignment,
tensor_layout=fqn_to_tensor_layout[fqn],
)
buffers.append(emitter_output.external_constant_buffer[index])

# Extract external data.
key_to_data: Dict[str, DataEntry] = {}
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
key_to_buffer_index = named_data.external_data.get(tag, {})
key_to_buffer_index = named_data_store.external_data.get(tag, {})
for key, index in key_to_buffer_index.items():
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
key_to_data[key] = DataEntry(
len(buffers), named_data.buffers[index].alignment
assert key not in key_to_data_entry # key must be unique
key_to_data_entry[key] = DataEntry(
buffer_index=len(buffers),
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
alignment=named_data_store.buffers[index].alignment,
tensor_layout=None,
)
buffers.append(named_data.buffers[index].buffer)
buffers.append(named_data_store.buffers[index].buffer)

# Serialize into PTD file.
ptd_files[tag] = data_serializer.serialize(
DataPayload(
buffers=buffers,
fqn_to_tensor=fqn_to_tensor_entry,
key_to_data=key_to_data,
named_data=key_to_data_entry,
)
)

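The hunks above replace the old TensorEntry map and the separate key_to_data map with a single key_to_data_entry dict of DataEntry values, so external tensors and opaque named data for one external tag end up in the same DataPayload. A minimal sketch of that shape follows; it is not part of the diff, the key names, alignment values, and byte contents are illustrative, and ScalarType.FLOAT is assumed to be a valid member of executorch.exir.schema.ScalarType.

from typing import Dict, List

from executorch.exir._serialize.data_serializer import DataEntry, DataPayload
from executorch.exir.schema import ScalarType
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout

buffers: List[bytes] = []
named_data: Dict[str, DataEntry] = {}

# An external tensor entry: carries a TensorLayout and the constant-tensor alignment.
layout = TensorLayout(ScalarType.FLOAT, [2, 2], [0, 1])  # scalar_type, sizes, dim_order
named_data["linear.weight"] = DataEntry(
    buffer_index=len(buffers),
    alignment=16,  # stands in for config.constant_tensor_alignment
    tensor_layout=layout,
)
buffers.append(b"\x00" * 16)  # placeholder tensor bytes

# An opaque named-data entry: no layout; alignment comes from the named data store buffer.
named_data["delegate_blob"] = DataEntry(
    buffer_index=len(buffers),
    alignment=4,
    tensor_layout=None,
)
buffers.append(b"\x01\x02\x03\x04")

payload = DataPayload(buffers=buffers, named_data=named_data)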
64 changes: 16 additions & 48 deletions exir/_serialize/data_serializer.py
@@ -1,41 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Sequence
from typing import Dict, Optional, Sequence

from executorch.exir._serialize._cord import Cord

from executorch.exir.schema import ScalarType


@dataclass
class TensorLayout:
"""Tensor layout information for externally-serialized tensors.
Attributes:
scalar_type: type of the elements in the tensor.
sizes: size of each dim in the tensor.
dim_order: specifies the order the dimensions are laid out in memory,
from outer to inner.
"""

scalar_type: ScalarType
sizes: List[int]
dim_order: List[int]


@dataclass
class TensorEntry:
"""Represents a single tensor in `DataPayload`, specifying its location
and metadata.
Attributes:
buffer_index: The index inside `DataPayload.buffers` that this
TensorEntry refers to.
layout: Metadata about the tensor.
"""

buffer_index: int
layout: TensorLayout
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout


@dataclass
@@ -47,33 +15,33 @@ class DataEntry:
buffer_index: The index inside `DataPayload.buffers` that this
DataEntry refers to.
alignment: The alignment of the data.
tensor_layout: If this is a tensor, the tensor layout information.
"""

buffer_index: int
alignment: int
tensor_layout: Optional[TensorLayout]


@dataclass
class DataPayload:
"""Contains the data and metadata required for serialization.
Having an index-based arrangement instead of embedding the buffers in
TensorEntry allows the caller to deduplicate buffers and point multiple
fully qualified names (FQNs) to the same entry.
DataEntry allows the caller to deduplicate buffers and point multiple
keys to the same entry.
Attributes:
buffers: a sequence of tensor buffers.
fqn_to_tensor: a map from fully qualified names to serializable tensors.
key_to_data: a map from unique keys to serializable opaque data.
buffers: a sequence of byte buffers.
key_to_data: a map from unique keys to serializable data.
"""

buffers: Sequence[bytes]
fqn_to_tensor: Dict[str, TensorEntry]
key_to_data: Dict[str, DataEntry]
named_data: Dict[str, DataEntry]


class DataSerializer(ABC):
"""Serializes and deserializes FQN-tagged tensor data.
"""Serializes and deserializes data. Data can be referenced by a unique key.
This base class enables serialization into different formats. See
executorch/extension/flat_tensor/ for an example.
@@ -85,11 +53,11 @@ def serialize(
data: DataPayload,
) -> Cord:
"""
Serializes a list of tensors emitted by ExecuTorch into a binary blob.
Serializes a list of bytes emitted by ExecuTorch into a binary blob.
Args:
data: the tensor buffers and tensor layout information required for
serialization.
data: buffers and corresponding metadata used for serialization.
Returns:
A binary blob that contains the serialized data.
@@ -99,14 +67,14 @@ def serialize(
@abstractmethod
def deserialize(self, blob: Cord) -> DataPayload:
"""
Deserializes a blob into a list of tensors. Reverses the effect of
Deserializes a blob into a DataPayload. Reverses the effect of
serialize.
Args:
blob: A binary blob that contains the serialized data.
Returns:
DataPayload: tensor buffers and tensor layout information
deserialized from `blob`.
DataPayload: buffers and corresponding metadata deserialized
from `blob`.
"""
raise NotImplementedError("deserialize_data")
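With TensorEntry gone, a concrete DataSerializer only needs the unified DataPayload. A small usage sketch follows; the helper name write_ptd, the call site, and the use of bytes() on the returned Cord are illustrative assumptions rather than part of this change, and the real FlatTensor implementation lives under executorch/extension/flat_tensor/.

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer


def write_ptd(serializer: DataSerializer, payload: DataPayload, path: str) -> None:
    # Hand the unified payload (buffers plus named_data) to any concrete serializer.
    blob: Cord = serializer.serialize(payload)
    with open(path, "wb") as f:
        # Assumes Cord can be flattened with bytes(); adapt if the real API differs.
        f.write(bytes(blob))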