13
13
from dataclasses import dataclass
14
14
from typing import ClassVar , List , Literal , Optional , Tuple
15
15
16
+ from executorch .exir ._serialize ._cord import Cord
16
17
from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
17
18
from executorch .exir ._serialize ._flatbuffer import (
18
19
_FlatbufferResult ,
29
30
Program ,
30
31
SubsegmentOffsets ,
31
32
)
33
+ from executorch .exir .tensor import ALIGNMENT
32
34
33
35
34
36
# Byte order of numbers written to program headers. Always little-endian
@@ -240,15 +242,15 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
240
242
241
243
242
244
def _extract_delegate_segments (
243
- program : Program , segments : List [bytes ], segment_alignment : int
245
+ program : Program ,
246
+ segments : List [Cord ],
244
247
) -> None :
245
- """The input program and segments list are modified in place.
248
+ """Extracts the delegate segments inlined in the program into a list of buffers.
249
+ The program is modified in-place to remove the delegate data.
246
250
247
251
Args:
248
252
program: The program to extract segments from. Modified in-place.
249
- segments: A list to which extracted segments will be appended. Modified in-place.
250
- segment_alignment: Alignment in bytes. The starting offset of each
251
- segment will be aligned to this value.
253
+ segments: A list of buffers to append extracted segments to. Modified in-place.
252
254
"""
253
255
remaining_inline : List [BackendDelegateInlineData ] = []
254
256
inline_indices_seen : set [int ] = set ()
@@ -278,24 +280,11 @@ def _extract_delegate_segments(
278
280
if inline .data :
279
281
# Move the delegate data out of the program.
280
282
segment_index = len (segments )
281
- segments .append (inline .data )
283
+ segments .append (Cord ( inline .data ) )
282
284
delegate .processed = BackendDelegateDataReference (
283
285
location = DataLocation .SEGMENT ,
284
286
index = segment_index ,
285
287
)
286
-
287
- # Update the segment list in the root Program object.
288
- prev_end = (
289
- program .segments [- 1 ].offset + program .segments [- 1 ].size
290
- if program .segments
291
- else 0
292
- )
293
- program .segments .append (
294
- DataSegment (
295
- offset = _aligned_size (prev_end , segment_alignment ),
296
- size = len (inline .data ),
297
- ),
298
- )
299
288
else :
300
289
# Not moving into a segment. Keep it inline, but update the
301
290
# index.
@@ -321,183 +310,32 @@ def _extract_delegate_segments(
321
310
def _extract_constant_segment (
322
311
constant_buffer : List [Buffer ],
323
312
tensor_alignment : int ,
324
- ) -> Tuple [bytes , List [int ]]:
325
- """Copies the tensors from the provided list into a single buffer and tracks the offsets
326
- of each tensor.
313
+ ) -> Tuple [Cord , List [int ]]:
314
+ """Copies the tensors from the provided list into a Cord and tracks the offsets
315
+ of each tensor.
327
316
317
+ Args:
328
318
constant_buffer: list of Buffers from which to extract constants. Not modified.
329
- tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
330
- constant segment will be aligned to this value. Default to 16 .
319
+ tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
320
+ with this value (the last tensor is left unpadded). No default; the caller must pass it.
331
321
332
322
Returns:
333
323
A tuple of (constant segment, list of offsets for each tensor in the segment)
334
324
"""
335
- constant_segment_data : bytearray = bytearray ()
325
+ constant_segment_data : Cord = Cord ()
336
326
constant_segment_offsets : List [int ] = []
337
327
current_offset : int = 0
338
328
for i in range (len (constant_buffer )):
339
329
buffer = constant_buffer [i ]
330
+ constant_segment_data .append (buffer .storage )
340
331
buffer_length = len (buffer .storage )
341
332
pad_length = _padding_required (buffer_length , tensor_alignment )
342
-
343
- # Append each constant buffer to the constant segment.
344
- constant_segment_data += buffer .storage
345
- # Add padding for all but the last tensor.
346
333
if i < len (constant_buffer ) - 1 :
347
- constant_segment_data += b"\x00 " * pad_length
348
-
349
- # Append constant data offset.
334
+ constant_segment_data .append (b"\x00 " * pad_length )
350
335
constant_segment_offsets .append (current_offset )
351
336
current_offset += buffer_length + pad_length
352
- return bytes (constant_segment_data ), constant_segment_offsets
353
-
354
-
355
- def _extract_segments (
356
- program : Program ,
357
- extract_delegate_segments : bool ,
358
- extract_constant_segment : bool ,
359
- segment_alignment : int ,
360
- constant_tensor_alignment : int ,
361
- ) -> Tuple [Program , List [bytes ]]:
362
- """Extracts constant and/or delegate data from a given Program into separate segments.
363
-
364
- Args:
365
- program: The Program to extract segments from.
366
- extract_delegate_segments: Whether to extract delegate data blobs from the program.
367
- extract_constant_segment: Whether to extract constant data from the program.
368
- segment_alignment: Alignment in bytes. The starting offset of each
369
- segment will be aligned to this value in the output data.
370
- constant_tensor_alignment: Alignment in bytes. The starting offset of each tensor
371
- in the constant segment will be aligned to this value.
372
- Returns:
373
- A tuple of (modified program, list of segment data).
374
- Raises:
375
- ValueError, if the program already contains segments.
376
- """
377
- if program .segments :
378
- raise ValueError (
379
- f"Program already has { len (program .segments )} segments: "
380
- + f"{ repr (program .segments )} "
381
- )
382
-
383
- # Don't modify the original program.
384
- # TODO(T144120904): Could avoid yet more huge copies with a more shallow
385
- # copy, reusing the actual data blobs.
386
- program = copy .deepcopy (program )
387
-
388
- # Segment data to be written to the file following the flatbuffer data.
389
- segments : List [bytes ] = []
390
-
391
- if extract_constant_segment :
392
- constant_segment_data , constant_segment_offsets = _extract_constant_segment (
393
- program .constant_buffer , tensor_alignment = constant_tensor_alignment
394
- )
395
-
396
- if constant_segment_data :
397
- # Append constant_segment_data to the list of segments if non-empty.
398
- segments .append (constant_segment_data )
399
- # Append constant_segment offset to the list of DataSegments. Added as the
400
- # first segment here, but it's not mandatory that the constant segment be
401
- # the first one in the file.
402
- program .segments .append (
403
- DataSegment (offset = 0 , size = len (constant_segment_data ))
404
- )
405
-
406
- # Fill in constant_segment offsets and clear the constant buffer; only one of
407
- # constant_segment and constant_buffer should be non-empty.
408
- program .constant_segment = SubsegmentOffsets (
409
- segment_index = 0 , offsets = constant_segment_offsets
410
- )
411
- program .constant_buffer = []
412
-
413
- if extract_delegate_segments :
414
- _extract_delegate_segments (
415
- program , segments = segments , segment_alignment = segment_alignment
416
- )
417
- return program , segments
418
-
419
-
420
- def _append_segments (
421
- program_data : bytes ,
422
- segments : List [bytes ],
423
- alignment : int ,
424
- segment_table : List [DataSegment ],
425
- base_offset : int ,
426
- ) -> bytes :
427
- """Appends segments to the end of the program data.
428
-
429
- Appends each element of `segments` to `program_data`, with '\0 ' padding to
430
- ensure that the offset of each segment is aligned to `alignment`.
431
-
432
- Args:
433
- program_data: The flatbuffer-serialized Program.
434
- segments: The list of segments to append to `program_data`.
435
- alignment: Alignment in bytes. The starting offset of each
436
- segment will be aligned to this value in the output data.
437
- segment_table: The expected offsets and sizes of each element in
438
- `segments`. This is typically `program.segments`. Must have the
439
- same length as `segments`.
440
- base_offset: The expected segment base offset from the extended header.
441
- Should point to the aligned offset following the end of
442
- `program_data`.
443
- Returns:
444
- A copy of `program_data` with the segment data and padding appended.
445
- If there are no segments, returns `program_data` directly.
446
- Raises:
447
- ValueError: If the length of `segments` doesn't match the length of
448
- `segment_table`.
449
- """
450
- if len (segments ) != len (segment_table ):
451
- raise ValueError (
452
- f"Segments length { len (segments )} does not match "
453
- + f"segment_table length { len (segment_table )} "
454
- )
455
- if not segments :
456
- return program_data
457
-
458
- # The pieces that will be concatenated to create the output data.
459
- # `program_data` will be its first element.
460
- padded_segments : List [bytes ] = []
461
- # Length of all elements in padded_segments. Only used for assertions.
462
- current_offset : int = 0
463
- for i , segment in enumerate ([program_data ] + segments ):
464
- # Add padding if necessary to align the start of this segment.
465
- pad_length : int = _padding_required (current_offset , alignment )
466
- if pad_length > 0 :
467
- padded_segments .append (b"\x00 " * pad_length )
468
- current_offset += pad_length
469
-
470
- # Make sure that we're about to add this segment to the offset that
471
- # agrees with program.segments. Skip the first entry, which is the
472
- # Program itself and isn't included in program.segments.
473
- if i == 1 :
474
- # The first real segment should start at the base offset.
475
- assert current_offset == base_offset , (
476
- f"Offset of first segment { current_offset } "
477
- + f"!= base_offset { base_offset } "
478
- )
479
- if i > 0 :
480
- # Adding a real segment, not `program_data`.
481
- expected_segment = segment_table [i - 1 ]
482
- expected_offset = base_offset + expected_segment .offset
483
- assert current_offset == expected_offset , (
484
- f"Segment { i } offset { current_offset } "
485
- + f"!= expected offset { expected_offset } "
486
- + f"(base { base_offset } + { expected_segment .offset } ) "
487
- )
488
- assert expected_segment .size == len (segment ), (
489
- f"Segment { i } size { len (segment )} "
490
- + f"!= expected size { expected_segment .size } "
491
- )
492
-
493
- # Add the payload. If this is the final segment, it does not need
494
- # padding after it.
495
- padded_segments .append (segment )
496
- current_offset += len (segment )
497
337
498
- # Use join() instead of appending to avoid O(n) reallocation of these
499
- # potentially-large buffers.
500
- return b"" .join (padded_segments )
338
+ return constant_segment_data , constant_segment_offsets
501
339
502
340
503
341
def serialize_pte_binary (
@@ -524,9 +362,8 @@ def serialize_pte_binary(
524
362
into a separate segment.
525
363
segment_alignment: Alignment in bytes. The starting offset of each
526
364
segment will be aligned to this value in the output data.
527
- constant_tensor_alignment: If provided, the minimum alignment of tensor
528
- buffers in the program. Must be a power of 2. If not provided, uses
529
- the value in the schema file.
365
+ constant_tensor_alignment: The minimum alignment of tensor
366
+ buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
530
367
delegate_alignment: If provided, the minimum alignment of delegate data
531
368
in the program. Must be a power of 2. If not provided, uses the
532
369
value in the schema file.
@@ -535,20 +372,53 @@ def serialize_pte_binary(
535
372
"""
536
373
# Default tensor alignment.
537
374
if constant_tensor_alignment is None :
538
- constant_tensor_alignment = 16
375
+ constant_tensor_alignment = ALIGNMENT
539
376
540
- # Segment data to be written to the file following the flatbuffer data.
541
- segments : List [bytes ] = []
377
+ # Don't modify the original program.
378
+ # TODO(T144120904): Could avoid yet more huge copies with a more shallow
379
+ # copy, reusing the actual data blobs.
380
+ program = copy .deepcopy (program )
381
+
382
+ # Store extracted segment data; this may be constant data or delegate data.
383
+ segments : List [Cord ] = []
384
+
385
+ if extract_constant_segment :
386
+ constant_segment_data , constant_segment_offsets = _extract_constant_segment (
387
+ program .constant_buffer , tensor_alignment = constant_tensor_alignment
388
+ )
389
+ if len (constant_segment_data ) > 0 :
390
+ # Update program.constant_segment with constant subsegment offset information.
391
+ program .constant_segment = SubsegmentOffsets (
392
+ segment_index = len (segments ), offsets = constant_segment_offsets
393
+ )
394
+ # Clear the constant buffer, as constant data will be stored in segments.
395
+ program .constant_buffer = []
396
+ # Add the constant segment to the list of extracted segments.
397
+ segments .append (constant_segment_data )
542
398
543
- # Extract constant segment and delegate segments, if requested.
544
- if extract_constant_segment or extract_delegate_segments :
545
- program , segments = _extract_segments (
546
- program = program ,
547
- extract_delegate_segments = extract_delegate_segments ,
548
- extract_constant_segment = extract_constant_segment ,
549
- segment_alignment = segment_alignment ,
550
- constant_tensor_alignment = constant_tensor_alignment ,
399
+ if extract_delegate_segments :
400
+ _extract_delegate_segments (program , segments )
401
+
402
+ # Append all segments into a single Cord, adding any necessary padding to ensure that
403
+ # each segment begins at the required alignment.
404
+ # Update program.segments with the offsets to each segment.
405
+ segments_data = Cord ()
406
+ for data in segments :
407
+ prev_end = (
408
+ (program .segments [- 1 ].offset + program .segments [- 1 ].size )
409
+ if program .segments
410
+ else 0
411
+ )
412
+ program .segments .append (
413
+ DataSegment (
414
+ offset = _aligned_size (prev_end , segment_alignment ), size = len (data )
415
+ )
551
416
)
417
+ # Add to aggregate segments cord with padding.
418
+ padding_length = _padding_required (len (segments_data ), segment_alignment )
419
+ if padding_length > 0 :
420
+ segments_data .append (b"\x00 " * padding_length )
421
+ segments_data .append (data )
552
422
553
423
# Convert to a standard flatbuffer binary.
554
424
result : _FlatbufferResult = _program_json_to_flatbuffer (
@@ -558,7 +428,7 @@ def serialize_pte_binary(
558
428
)
559
429
560
430
# If there are no segments present, do not insert the extended header.
561
- if not segments :
431
+ if len ( segments_data ) == 0 :
562
432
return result .data
563
433
564
434
# Size of the header to insert. Its size is padded to the largest
@@ -572,7 +442,7 @@ def serialize_pte_binary(
572
442
# Offset to the first segment, or zero if there are no segments.
573
443
segment_base_offset : int = (
574
444
_aligned_size (input_size = program_size , alignment = segment_alignment )
575
- if segments
445
+ if len ( segments_data ) > 0
576
446
else 0
577
447
)
578
448
@@ -600,18 +470,21 @@ def serialize_pte_binary(
600
470
assert eh .program_size == program_size
601
471
assert eh .segment_base_offset == segment_base_offset
602
472
603
- if segments :
604
- # Add segments to the end of the data, in order, with the appropriate
605
- # padding.
606
- program_data = _append_segments (
607
- program_data = program_data ,
608
- segments = segments ,
609
- alignment = segment_alignment ,
610
- segment_table = program .segments ,
611
- base_offset = segment_base_offset ,
612
- )
613
-
614
- return program_data
473
+ # Construct the final pte file containing:
474
+ # - program data; written to offset 0.
475
+ # - segments data (optional); aligned to segment_alignment.
476
+ pte_data = Cord (program_data )
477
+ if len (segments_data ) > 0 :
478
+ padding_length = _padding_required (len (pte_data ), segment_alignment )
479
+ pte_data .append (b"\x00 " * padding_length )
480
+ # The first segment after program data should start at the segment base offset.
481
+ assert (
482
+ len (pte_data ) == segment_base_offset
483
+ ), f"Offset of first segment { len (pte_data )} != segment base offset { segment_base_offset } "
484
+ pte_data .append (segments_data )
485
+
486
+ # TODO(lfq): this creates a copy of all the data; once we update existing callsites this will change.
487
+ return bytes (pte_data )
615
488
616
489
617
490
def _restore_segments (program : Program , segment_data : bytes ) -> Program :
0 commit comments