Skip to content

Commit 47b837b

Browse files
lucylqfacebook-github-bot
authored andcommitted
Use cords to store constant and delegate segment data (#2281)
Summary: Pull Request resolved: #2281 Update `extract_constant_segment` and `extract_delegate_segment` to place constants and delegates into cord data structures. Update `serialize_pte_binary` to compile the cords together. Remove the original extract_constant_segment/extract_delegate_segment logic, remove append_segments. Reviewed By: dbort Differential Revision: D54523957 fbshipit-source-id: 196b00710f5980344406aa435eebe75a97430ddf
1 parent 8602e9e commit 47b837b

File tree

3 files changed

+83
-208
lines changed

3 files changed

+83
-208
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,6 @@ runtime.python_library(
6161
],
6262
deps = [
6363
"//executorch/exir:schema",
64+
"//executorch/exir:tensor",
6465
],
6566
)

exir/_serialize/_program.py

Lines changed: 81 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataclasses import dataclass
1414
from typing import ClassVar, List, Literal, Optional, Tuple
1515

16+
from executorch.exir._serialize._cord import Cord
1617
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1718
from executorch.exir._serialize._flatbuffer import (
1819
_FlatbufferResult,
@@ -29,6 +30,7 @@
2930
Program,
3031
SubsegmentOffsets,
3132
)
33+
from executorch.exir.tensor import ALIGNMENT
3234

3335

3436
# Byte order of numbers written to program headers. Always little-endian
@@ -240,15 +242,15 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
240242

241243

242244
def _extract_delegate_segments(
243-
program: Program, segments: List[bytes], segment_alignment: int
245+
program: Program,
246+
segments: List[Cord],
244247
) -> None:
245-
"""The input program and segments list are modified in place.
248+
"""Extracts the delegate segments inlined in the program into a list of buffers.
249+
The program is modified in-place to remove the delegate data.
246250
247251
Args:
248252
program: The program to extract segments from. Modified in-place.
249-
segments: A list to which extracted segments will be appended. Modified in-place.
250-
segment_alignment: Alignment in bytes. The starting offset of each
251-
segment will be aligned to this value.
253+
segments: A list of buffers to append extracted segments to. Modified in-place.
252254
"""
253255
remaining_inline: List[BackendDelegateInlineData] = []
254256
inline_indices_seen: set[int] = set()
@@ -278,24 +280,11 @@ def _extract_delegate_segments(
278280
if inline.data:
279281
# Move the delegate data out of the program.
280282
segment_index = len(segments)
281-
segments.append(inline.data)
283+
segments.append(Cord(inline.data))
282284
delegate.processed = BackendDelegateDataReference(
283285
location=DataLocation.SEGMENT,
284286
index=segment_index,
285287
)
286-
287-
# Update the segment list in the root Program object.
288-
prev_end = (
289-
program.segments[-1].offset + program.segments[-1].size
290-
if program.segments
291-
else 0
292-
)
293-
program.segments.append(
294-
DataSegment(
295-
offset=_aligned_size(prev_end, segment_alignment),
296-
size=len(inline.data),
297-
),
298-
)
299288
else:
300289
# Not moving into a segment. Keep it inline, but update the
301290
# index.
@@ -321,183 +310,32 @@ def _extract_delegate_segments(
321310
def _extract_constant_segment(
322311
constant_buffer: List[Buffer],
323312
tensor_alignment: int,
324-
) -> Tuple[bytes, List[int]]:
325-
"""Copies the tensors from the provided list into a single buffer and tracks the offsets
326-
of each tensor.
313+
) -> Tuple[Cord, List[int]]:
314+
"""Copies the tensors from the provided list into a Cord and tracks the offsets
315+
of each tensor.
327316
317+
Args:
328318
constant_buffer: list of Buffers from which to extract constants from. Not modified.
329-
tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
330-
constant segment will be aligned to this value. Default to 16.
319+
tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
320+
with this value. Defaults to ALIGNMENT.
331321
332322
Returns:
333323
A tuple of (constant segment, list of offsets for each tensor in the segment)
334324
"""
335-
constant_segment_data: bytearray = bytearray()
325+
constant_segment_data: Cord = Cord()
336326
constant_segment_offsets: List[int] = []
337327
current_offset: int = 0
338328
for i in range(len(constant_buffer)):
339329
buffer = constant_buffer[i]
330+
constant_segment_data.append(buffer.storage)
340331
buffer_length = len(buffer.storage)
341332
pad_length = _padding_required(buffer_length, tensor_alignment)
342-
343-
# Append each constant buffer to the constant segment.
344-
constant_segment_data += buffer.storage
345-
# Add padding for all but the last tensor.
346333
if i < len(constant_buffer) - 1:
347-
constant_segment_data += b"\x00" * pad_length
348-
349-
# Append constant data offset.
334+
constant_segment_data.append(b"\x00" * pad_length)
350335
constant_segment_offsets.append(current_offset)
351336
current_offset += buffer_length + pad_length
352-
return bytes(constant_segment_data), constant_segment_offsets
353-
354-
355-
def _extract_segments(
356-
program: Program,
357-
extract_delegate_segments: bool,
358-
extract_constant_segment: bool,
359-
segment_alignment: int,
360-
constant_tensor_alignment: int,
361-
) -> Tuple[Program, List[bytes]]:
362-
"""Extracts constant and/or delegate data from a given Program into separate segments.
363-
364-
Args:
365-
program: The Program to extract segments from.
366-
extract_delegate_segments: Whether to extract delegate data blobs from the program.
367-
extract_constant_segment: Whether to extract constant data from the program.
368-
segment_alignment: Alignment in bytes. The starting offset of each
369-
segment will be aligned to this value in the output data.
370-
constant_tensor_alignment: Alignment in bytes. The starting offset of each tensor
371-
in the constant segment will be aligned to this value.
372-
Returns:
373-
A tuple of (modified program, list of segment data).
374-
Raises:
375-
ValueError, if the program already contains segments.
376-
"""
377-
if program.segments:
378-
raise ValueError(
379-
f"Program already has {len(program.segments)} segments: "
380-
+ f"{repr(program.segments)}"
381-
)
382-
383-
# Don't modify the original program.
384-
# TODO(T144120904): Could avoid yet more huge copies with a more shallow
385-
# copy, reusing the actual data blobs.
386-
program = copy.deepcopy(program)
387-
388-
# Segment data to be written to the file following the flatbuffer data.
389-
segments: List[bytes] = []
390-
391-
if extract_constant_segment:
392-
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
393-
program.constant_buffer, tensor_alignment=constant_tensor_alignment
394-
)
395-
396-
if constant_segment_data:
397-
# Append constant_segment_data to the list of segments if non-empty.
398-
segments.append(constant_segment_data)
399-
# Append constant_segment offset to the list of DataSegments. Added as the
400-
# first segment here, but it's not mandatory that the constant segment be
401-
# the first one in the file.
402-
program.segments.append(
403-
DataSegment(offset=0, size=len(constant_segment_data))
404-
)
405-
406-
# Fill in constant_segment offsets and clear the constant buffer; only one of
407-
# constant_segment and constant_buffer should be non-empty.
408-
program.constant_segment = SubsegmentOffsets(
409-
segment_index=0, offsets=constant_segment_offsets
410-
)
411-
program.constant_buffer = []
412-
413-
if extract_delegate_segments:
414-
_extract_delegate_segments(
415-
program, segments=segments, segment_alignment=segment_alignment
416-
)
417-
return program, segments
418-
419-
420-
def _append_segments(
421-
program_data: bytes,
422-
segments: List[bytes],
423-
alignment: int,
424-
segment_table: List[DataSegment],
425-
base_offset: int,
426-
) -> bytes:
427-
"""Appends segments to the end of the program data.
428-
429-
Appends each element of `segments` to `program_data`, with '\0' padding to
430-
ensure that the offset of each segment is aligned to `alignment`.
431-
432-
Args:
433-
program_data: The flatbuffer-serialized Program.
434-
segments: The list of segments to append to `program_data`.
435-
alignment: Alignment in bytes. The starting offset of each
436-
segment will be aligned to this value in the output data.
437-
segment_table: The expected offsets and sizes of each element in
438-
`segments`. This is typically `program.segments`. Must have the
439-
same length as `segments`.
440-
base_offset: The expected segment base offset from the extended header.
441-
Should point to the aligned offset following the end of
442-
`program_data`.
443-
Returns:
444-
A copy of `program_data` with the segment data and padding appended.
445-
If there are no segments, returns `program_data` directly.
446-
Raises:
447-
ValueError: If the length of `segments` doesn't match the length of
448-
`segment_table`.
449-
"""
450-
if len(segments) != len(segment_table):
451-
raise ValueError(
452-
f"Segments length {len(segments)} does not match "
453-
+ f"segment_table length {len(segment_table)}"
454-
)
455-
if not segments:
456-
return program_data
457-
458-
# The pieces that will be concatenated to create the output data.
459-
# `program_data` will be its first element.
460-
padded_segments: List[bytes] = []
461-
# Length of all elements in padded_segments. Only used for assertions.
462-
current_offset: int = 0
463-
for i, segment in enumerate([program_data] + segments):
464-
# Add padding if necessary to align the start of this segment.
465-
pad_length: int = _padding_required(current_offset, alignment)
466-
if pad_length > 0:
467-
padded_segments.append(b"\x00" * pad_length)
468-
current_offset += pad_length
469-
470-
# Make sure that we're about to add this segment to the offset that
471-
# agrees with program.segments. Skip the first entry, which is the
472-
# Program itself and isn't included in program.segments.
473-
if i == 1:
474-
# The first real segment should start at the base offset.
475-
assert current_offset == base_offset, (
476-
f"Offset of first segment {current_offset} "
477-
+ f"!= base_offset {base_offset}"
478-
)
479-
if i > 0:
480-
# Adding a real segment, not `program_data`.
481-
expected_segment = segment_table[i - 1]
482-
expected_offset = base_offset + expected_segment.offset
483-
assert current_offset == expected_offset, (
484-
f"Segment {i} offset {current_offset} "
485-
+ f"!= expected offset {expected_offset} "
486-
+ f"(base {base_offset} + {expected_segment.offset}) "
487-
)
488-
assert expected_segment.size == len(segment), (
489-
f"Segment {i} size {len(segment)} "
490-
+ f"!= expected size {expected_segment.size}"
491-
)
492-
493-
# Add the payload. If this is the final segment, it does not need
494-
# padding after it.
495-
padded_segments.append(segment)
496-
current_offset += len(segment)
497337

498-
# Use join() instead of appending to avoid O(n) reallocation of these
499-
# potentially-large buffers.
500-
return b"".join(padded_segments)
338+
return constant_segment_data, constant_segment_offsets
501339

502340

503341
def serialize_pte_binary(
@@ -524,9 +362,8 @@ def serialize_pte_binary(
524362
into a separate segment.
525363
segment_alignment: Alignment in bytes. The starting offset of each
526364
segment will be aligned to this value in the output data.
527-
constant_tensor_alignment: If provided, the minimum alignment of tensor
528-
buffers in the program. Must be a power of 2. If not provided, uses
529-
the value in the schema file.
365+
constant_tensor_alignment: The minimum alignment of tensor
366+
buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
530367
delegate_alignment: If provided, the minimum alignment of delegate data
531368
in the program. Must be a power of 2. If not provided, uses the
532369
value in the schema file.
@@ -535,20 +372,53 @@ def serialize_pte_binary(
535372
"""
536373
# Default tensor alignment.
537374
if constant_tensor_alignment is None:
538-
constant_tensor_alignment = 16
375+
constant_tensor_alignment = ALIGNMENT
539376

540-
# Segment data to be written to the file following the flatbuffer data.
541-
segments: List[bytes] = []
377+
# Don't modify the original program.
378+
# TODO(T144120904): Could avoid yet more huge copies with a more shallow
379+
# copy, reusing the actual data blobs.
380+
program = copy.deepcopy(program)
381+
382+
# Store extracted segment data; this may be constant data or delegate data.
383+
segments: List[Cord] = []
384+
385+
if extract_constant_segment:
386+
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
387+
program.constant_buffer, tensor_alignment=constant_tensor_alignment
388+
)
389+
if len(constant_segment_data) > 0:
390+
# Update program.constant_segment with constant subsegment offset information.
391+
program.constant_segment = SubsegmentOffsets(
392+
segment_index=len(segments), offsets=constant_segment_offsets
393+
)
394+
# Clear the constant buffer, as constant data will be stored in segments.
395+
program.constant_buffer = []
396+
# Add to the aggregate segments cord.
397+
segments.append(constant_segment_data)
542398

543-
# Extract constant segment and delegate segments, if requested.
544-
if extract_constant_segment or extract_delegate_segments:
545-
program, segments = _extract_segments(
546-
program=program,
547-
extract_delegate_segments=extract_delegate_segments,
548-
extract_constant_segment=extract_constant_segment,
549-
segment_alignment=segment_alignment,
550-
constant_tensor_alignment=constant_tensor_alignment,
399+
if extract_delegate_segments:
400+
_extract_delegate_segments(program, segments)
401+
402+
# Append all segments into a single Cord, adding any necessary padding to ensure that
403+
# each segment begins at the required alignment.
404+
# Update program.segments with the offsets to each segment.
405+
segments_data = Cord()
406+
for data in segments:
407+
prev_end = (
408+
(program.segments[-1].offset + program.segments[-1].size)
409+
if program.segments
410+
else 0
411+
)
412+
program.segments.append(
413+
DataSegment(
414+
offset=_aligned_size(prev_end, segment_alignment), size=len(data)
415+
)
551416
)
417+
# Add to aggregate segments cord with padding.
418+
padding_length = _padding_required(len(segments_data), segment_alignment)
419+
if padding_length > 0:
420+
segments_data.append(b"\x00" * padding_length)
421+
segments_data.append(data)
552422

553423
# Convert to a standard flatbuffer binary.
554424
result: _FlatbufferResult = _program_json_to_flatbuffer(
@@ -558,7 +428,7 @@ def serialize_pte_binary(
558428
)
559429

560430
# If there are no segments present, do not insert the extended header.
561-
if not segments:
431+
if len(segments_data) == 0:
562432
return result.data
563433

564434
# Size of the header to insert. Its size is padded to the largest
@@ -572,7 +442,7 @@ def serialize_pte_binary(
572442
# Offset to the first segment, or zero if there are no segments.
573443
segment_base_offset: int = (
574444
_aligned_size(input_size=program_size, alignment=segment_alignment)
575-
if segments
445+
if len(segments_data) > 0
576446
else 0
577447
)
578448

@@ -600,18 +470,21 @@ def serialize_pte_binary(
600470
assert eh.program_size == program_size
601471
assert eh.segment_base_offset == segment_base_offset
602472

603-
if segments:
604-
# Add segments to the end of the data, in order, with the appropriate
605-
# padding.
606-
program_data = _append_segments(
607-
program_data=program_data,
608-
segments=segments,
609-
alignment=segment_alignment,
610-
segment_table=program.segments,
611-
base_offset=segment_base_offset,
612-
)
613-
614-
return program_data
473+
# Construct the final pte file containing:
474+
# - program data; written to offset 0.
475+
# - segments data (optional); aligned to segment_alignment.
476+
pte_data = Cord(program_data)
477+
if len(segments_data) > 0:
478+
padding_length = _padding_required(len(pte_data), segment_alignment)
479+
pte_data.append(b"\x00" * padding_length)
480+
# The first segment after program data should start at the segment base offset.
481+
assert (
482+
len(pte_data) == segment_base_offset
483+
), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
484+
pte_data.append(segments_data)
485+
486+
# TODO(lfq): this creates a copy of all the data; once we update existing callsites this will change.
487+
return bytes(pte_data)
615488

616489

617490
def _restore_segments(program: Program, segment_data: bytes) -> Program:

0 commit comments

Comments
 (0)