Skip to content

Use cords to store constant and delegate segment data #2281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,6 @@ runtime.python_library(
],
deps = [
"//executorch/exir:schema",
"//executorch/exir:tensor",
],
)
289 changes: 81 additions & 208 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
from executorch.exir._serialize._flatbuffer import (
_FlatbufferResult,
Expand All @@ -29,6 +30,7 @@
Program,
SubsegmentOffsets,
)
from executorch.exir.tensor import ALIGNMENT


# Byte order of numbers written to program headers. Always little-endian
Expand Down Expand Up @@ -240,15 +242,15 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:


def _extract_delegate_segments(
program: Program, segments: List[bytes], segment_alignment: int
program: Program,
segments: List[Cord],
) -> None:
"""The input program and segments list are modified in place.
"""Extracts the delegate segments inlined in the program into a list of buffers.
The program is modified in-place to remove the delegate data.

Args:
program: The program to extract segments from. Modified in-place.
segments: A list to which extracted segments will be appended. Modified in-place.
segment_alignment: Alignment in bytes. The starting offset of each
segment will be aligned to this value.
segments: A list of buffers to append extracted segments to. Modified in-place.
"""
remaining_inline: List[BackendDelegateInlineData] = []
inline_indices_seen: set[int] = set()
Expand Down Expand Up @@ -278,24 +280,11 @@ def _extract_delegate_segments(
if inline.data:
# Move the delegate data out of the program.
segment_index = len(segments)
segments.append(inline.data)
segments.append(Cord(inline.data))
delegate.processed = BackendDelegateDataReference(
location=DataLocation.SEGMENT,
index=segment_index,
)

# Update the segment list in the root Program object.
prev_end = (
program.segments[-1].offset + program.segments[-1].size
if program.segments
else 0
)
program.segments.append(
DataSegment(
offset=_aligned_size(prev_end, segment_alignment),
size=len(inline.data),
),
)
else:
# Not moving into a segment. Keep it inline, but update the
# index.
Expand All @@ -321,183 +310,32 @@ def _extract_delegate_segments(
def _extract_constant_segment(
constant_buffer: List[Buffer],
tensor_alignment: int,
) -> Tuple[bytes, List[int]]:
"""Copies the tensors from the provided list into a single buffer and tracks the offsets
of each tensor.
) -> Tuple[Cord, List[int]]:
"""Copies the tensors from the provided list into a Cord and tracks the offsets
of each tensor.

Args:
constant_buffer: List of Buffers from which to extract constants. Not modified.
tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
constant segment will be aligned to this value. Default to 16.
tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
with this value. Defaults to ALIGNMENT.

Returns:
A tuple of (constant segment, list of offsets for each tensor in the segment)
"""
constant_segment_data: bytearray = bytearray()
constant_segment_data: Cord = Cord()
constant_segment_offsets: List[int] = []
current_offset: int = 0
for i in range(len(constant_buffer)):
buffer = constant_buffer[i]
constant_segment_data.append(buffer.storage)
buffer_length = len(buffer.storage)
pad_length = _padding_required(buffer_length, tensor_alignment)

# Append each constant buffer to the constant segment.
constant_segment_data += buffer.storage
# Add padding for all but the last tensor.
if i < len(constant_buffer) - 1:
constant_segment_data += b"\x00" * pad_length

# Append constant data offset.
constant_segment_data.append(b"\x00" * pad_length)
constant_segment_offsets.append(current_offset)
current_offset += buffer_length + pad_length
return bytes(constant_segment_data), constant_segment_offsets


def _extract_segments(
    program: Program,
    extract_delegate_segments: bool,
    extract_constant_segment: bool,
    segment_alignment: int,
    constant_tensor_alignment: int,
) -> Tuple[Program, List[bytes]]:
    """Moves constant and/or delegate data out of a Program into segments.

    The caller's program is never touched: all changes are applied to a deep
    copy, which is what gets returned.

    Args:
        program: The Program to extract segments from.
        extract_delegate_segments: Whether to extract delegate data blobs
            from the program.
        extract_constant_segment: Whether to extract constant data from the
            program.
        segment_alignment: Alignment in bytes. Forwarded to
            `_extract_delegate_segments`, which uses it to compute the
            starting offset recorded for each delegate segment.
        constant_tensor_alignment: Alignment in bytes. Forwarded to
            `_extract_constant_segment` as the per-tensor alignment inside
            the constant segment.

    Returns:
        A tuple of (modified copy of the program, list of segment data).

    Raises:
        ValueError: If the program already contains segments.
    """
    # Refuse to layer new segments on top of a pre-existing segment table.
    if program.segments:
        raise ValueError(
            f"Program already has {len(program.segments)} segments: "
            + f"{repr(program.segments)}"
        )

    # Work on a private copy so the original program stays intact.
    # TODO(T144120904): Could avoid yet more huge copies with a more shallow
    # copy, reusing the actual data blobs.
    working = copy.deepcopy(program)

    # Payloads that will be written to the file after the flatbuffer data.
    extracted: List[bytes] = []

    if extract_constant_segment:
        constant_blob, tensor_offsets = _extract_constant_segment(
            working.constant_buffer, tensor_alignment=constant_tensor_alignment
        )

        # NOTE(review): the nesting below was reconstructed from a diff whose
        # indentation was stripped; all four updates are assumed to be
        # guarded by the non-empty check -- confirm against the real file.
        if constant_blob:
            # The constant data becomes a segment of its own. It is recorded
            # at offset 0 (i.e. first), although nothing requires the
            # constant segment to be the first one in the file.
            extracted.append(constant_blob)
            working.segments.append(DataSegment(offset=0, size=len(constant_blob)))

            # Only one of constant_segment / constant_buffer should be
            # populated, so record the offsets and empty the inline buffer.
            working.constant_segment = SubsegmentOffsets(
                segment_index=0, offsets=tensor_offsets
            )
            working.constant_buffer = []

    if extract_delegate_segments:
        _extract_delegate_segments(
            working, segments=extracted, segment_alignment=segment_alignment
        )

    return working, extracted


def _append_segments(
    program_data: bytes,
    segments: List[bytes],
    alignment: int,
    segment_table: List[DataSegment],
    base_offset: int,
) -> bytes:
    """Concatenates the program data with its segments, padding for alignment.

    Every element of `segments` is placed after `program_data`, preceded by
    however many '\0' bytes `_padding_required` reports for the running
    offset and `alignment`.

    Args:
        program_data: The flatbuffer-serialized Program.
        segments: The list of segments to append to `program_data`.
        alignment: Alignment in bytes used when padding before each piece.
        segment_table: The expected offsets and sizes of each element in
            `segments`. This is typically `program.segments`. Must have the
            same length as `segments`.
        base_offset: The expected segment base offset from the extended
            header. Should point to the aligned offset following the end of
            `program_data`.

    Returns:
        A new bytes object holding `program_data` followed by the padded
        segments. If there are no segments, `program_data` itself is
        returned.

    Raises:
        ValueError: If the length of `segments` doesn't match the length of
            `segment_table`.
    """
    if len(segments) != len(segment_table):
        raise ValueError(
            f"Segments length {len(segments)} does not match "
            + f"segment_table length {len(segment_table)}"
        )
    if not segments:
        return program_data

    # Chunks to be joined into the output; `program_data` ends up first.
    pieces: List[bytes] = []
    # Running total of the bytes collected in `pieces`. It drives the padding
    # computation and the consistency assertions below.
    current_offset: int = 0

    for i, segment in enumerate([program_data] + segments):
        # Pad so that this piece starts on an aligned offset.
        gap: int = _padding_required(current_offset, alignment)
        if gap > 0:
            pieces.append(b"\x00" * gap)
            current_offset += gap

        # Index 0 is the Program itself, which has no segment_table entry;
        # every later index is a real segment and must match the table.
        if i > 0:
            if i == 1:
                # The first real segment should start at the base offset.
                assert current_offset == base_offset, (
                    f"Offset of first segment {current_offset} "
                    + f"!= base_offset {base_offset}"
                )
            expected_segment = segment_table[i - 1]
            expected_offset = base_offset + expected_segment.offset
            assert current_offset == expected_offset, (
                f"Segment {i} offset {current_offset} "
                + f"!= expected offset {expected_offset} "
                + f"(base {base_offset} + {expected_segment.offset}) "
            )
            assert expected_segment.size == len(segment), (
                f"Segment {i} size {len(segment)} "
                + f"!= expected size {expected_segment.size}"
            )

        # Add the payload. No padding follows it here; the next iteration
        # pads only if another piece comes after.
        pieces.append(segment)
        current_offset += len(segment)

    # A single join avoids repeatedly reallocating these potentially large
    # buffers, as incremental concatenation would.
    return b"".join(pieces)
return constant_segment_data, constant_segment_offsets


def serialize_pte_binary(
Expand All @@ -524,9 +362,8 @@ def serialize_pte_binary(
into a separate segment.
segment_alignment: Alignment in bytes. The starting offset of each
segment will be aligned to this value in the output data.
constant_tensor_alignment: If provided, the minimum alignment of tensor
buffers in the program. Must be a power of 2. If not provided, uses
the value in the schema file.
constant_tensor_alignment: The minimum alignment of tensor
buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
delegate_alignment: If provided, the minimum alignment of delegate data
in the program. Must be a power of 2. If not provided, uses the
value in the schema file.
Expand All @@ -535,20 +372,53 @@ def serialize_pte_binary(
"""
# Default tensor alignment.
if constant_tensor_alignment is None:
constant_tensor_alignment = 16
constant_tensor_alignment = ALIGNMENT

# Segment data to be written to the file following the flatbuffer data.
segments: List[bytes] = []
# Don't modify the original program.
# TODO(T144120904): Could avoid yet more huge copies with a more shallow
# copy, reusing the actual data blobs.
program = copy.deepcopy(program)

# Store extracted segment data; this may be constant data or delegate data.
segments: List[Cord] = []

if extract_constant_segment:
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
program.constant_buffer, tensor_alignment=constant_tensor_alignment
)
if len(constant_segment_data) > 0:
# Update program.constant_segment with constant subsegment offset information.
program.constant_segment = SubsegmentOffsets(
segment_index=len(segments), offsets=constant_segment_offsets
)
# Clear the constant buffer, as constant data will be stored in segments.
program.constant_buffer = []
# Add to the list of extracted segments; the aggregate cord is built below.
segments.append(constant_segment_data)

# Extract constant segment and delegate segments, if requested.
if extract_constant_segment or extract_delegate_segments:
program, segments = _extract_segments(
program=program,
extract_delegate_segments=extract_delegate_segments,
extract_constant_segment=extract_constant_segment,
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
if extract_delegate_segments:
_extract_delegate_segments(program, segments)

# Append all segments into a single Cord, adding any necessary padding to ensure that
# each segment begins at the required alignment.
# Update program.segments with the offsets to each segment.
segments_data = Cord()
for data in segments:
prev_end = (
(program.segments[-1].offset + program.segments[-1].size)
if program.segments
else 0
)
program.segments.append(
DataSegment(
offset=_aligned_size(prev_end, segment_alignment), size=len(data)
)
)
# Add to aggregate segments cord with padding.
padding_length = _padding_required(len(segments_data), segment_alignment)
if padding_length > 0:
segments_data.append(b"\x00" * padding_length)
segments_data.append(data)

# Convert to a standard flatbuffer binary.
result: _FlatbufferResult = _program_json_to_flatbuffer(
Expand All @@ -558,7 +428,7 @@ def serialize_pte_binary(
)

# If there are no segments present, do not insert the extended header.
if not segments:
if len(segments_data) == 0:
return result.data

# Size of the header to insert. Its size is padded to the largest
Expand All @@ -572,7 +442,7 @@ def serialize_pte_binary(
# Offset to the first segment, or zero if there are no segments.
segment_base_offset: int = (
_aligned_size(input_size=program_size, alignment=segment_alignment)
if segments
if len(segments_data) > 0
else 0
)

Expand Down Expand Up @@ -600,18 +470,21 @@ def serialize_pte_binary(
assert eh.program_size == program_size
assert eh.segment_base_offset == segment_base_offset

if segments:
# Add segments to the end of the data, in order, with the appropriate
# padding.
program_data = _append_segments(
program_data=program_data,
segments=segments,
alignment=segment_alignment,
segment_table=program.segments,
base_offset=segment_base_offset,
)

return program_data
# Construct the final pte file containing:
# - program data; written to offset 0.
# - segments data (optional); aligned to segment_alignment.
pte_data = Cord(program_data)
if len(segments_data) > 0:
padding_length = _padding_required(len(pte_data), segment_alignment)
pte_data.append(b"\x00" * padding_length)
# The first segment after program data should start at the segment base offset.
assert (
len(pte_data) == segment_base_offset
), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
pte_data.append(segments_data)

# TODO(lfq): this creates a copy of all the data; once we update existing callsites this will change.
return bytes(pte_data)


def _restore_segments(program: Program, segment_data: bytes) -> Program:
Expand Down
Loading