11
11
import re
12
12
13
13
from dataclasses import dataclass
14
- from typing import ClassVar , List , Literal , Optional , Tuple
14
+ from typing import ClassVar , List , Optional , Tuple
15
15
16
16
from executorch .exir ._serialize ._cord import Cord
17
17
from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
21
21
_program_json_to_flatbuffer ,
22
22
)
23
23
24
+ from executorch .exir ._serialize .utils import (
25
+ aligned_size ,
26
+ HEADER_BYTEORDER ,
27
+ pad_to ,
28
+ padding_required ,
29
+ )
30
+
24
31
from executorch .exir .schema import (
25
32
BackendDelegateDataReference ,
26
33
BackendDelegateInlineData ,
33
40
from executorch .exir .tensor import ALIGNMENT
34
41
35
42
36
- # Byte order of numbers written to program headers. Always little-endian
37
- # regardless of the host system, since all commonly-used modern CPUs are little
38
- # endian.
39
- _HEADER_BYTEORDER : Literal ["little" ] = "little"
40
-
41
-
42
43
def _program_to_json (program : Program ) -> str :
43
44
"""Returns the JSON representation of the given Program."""
44
45
return json .dumps (program , cls = _DataclassEncoder )
@@ -50,19 +51,6 @@ def _json_to_program(program_json: bytes) -> Program:
50
51
return _json_to_dataclass (json .loads (program_json ), cls = Program )
51
52
52
53
53
- def _padding_required (offset : int , alignment : int ) -> int :
54
- """Returns the padding required to align `offset` to `alignment`."""
55
- remainder : int = offset % alignment
56
- if remainder != 0 :
57
- return alignment - remainder
58
- return 0
59
-
60
-
61
- def _aligned_size (input_size : int , alignment : int ) -> int :
62
- """Returns input_size padded up to the next whole multiple of alignment."""
63
- return input_size + _padding_required (input_size , alignment )
64
-
65
-
66
54
def _insert_flatbuffer_header (
67
55
flatbuffer_data : bytes , magic_regex : str , header_data : bytes
68
56
) -> bytes :
@@ -102,11 +90,11 @@ def _insert_flatbuffer_header(
102
90
return flatbuffer_data
103
91
104
92
# We will need to adjust the root object offset after inserting the header.
105
- root_offset = int .from_bytes (flatbuffer_data [0 :4 ], byteorder = _HEADER_BYTEORDER )
93
+ root_offset = int .from_bytes (flatbuffer_data [0 :4 ], byteorder = HEADER_BYTEORDER )
106
94
107
95
return (
108
96
# New root offset.
109
- (root_offset + len (header_data )).to_bytes (4 , byteorder = _HEADER_BYTEORDER )
97
+ (root_offset + len (header_data )).to_bytes (4 , byteorder = HEADER_BYTEORDER )
110
98
# Existing magic bytes.
111
99
+ flatbuffer_data [4 :8 ]
112
100
# Provided header + padding.
@@ -171,11 +159,9 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
171
159
172
160
return _ExtendedHeader (
173
161
magic = data [0 :4 ],
174
- length = int .from_bytes (data [4 :8 ], byteorder = _HEADER_BYTEORDER ),
175
- program_size = int .from_bytes (data [8 :16 ], byteorder = _HEADER_BYTEORDER ),
176
- segment_base_offset = int .from_bytes (
177
- data [16 :24 ], byteorder = _HEADER_BYTEORDER
178
- ),
162
+ length = int .from_bytes (data [4 :8 ], byteorder = HEADER_BYTEORDER ),
163
+ program_size = int .from_bytes (data [8 :16 ], byteorder = HEADER_BYTEORDER ),
164
+ segment_base_offset = int .from_bytes (data [16 :24 ], byteorder = HEADER_BYTEORDER ),
179
165
)
180
166
181
167
def is_valid (self ) -> bool :
@@ -201,35 +187,16 @@ def to_bytes(self) -> bytes:
201
187
# fields to this header in the future. Always use the proper size
202
188
# (i.e., ignore self.length) since there's no reason to create an
203
189
# invalid header.
204
- + self .EXPECTED_LENGTH .to_bytes (4 , byteorder = _HEADER_BYTEORDER )
190
+ + self .EXPECTED_LENGTH .to_bytes (4 , byteorder = HEADER_BYTEORDER )
205
191
# uint64_t: Size of the flatbuffer data, including this header.
206
- + self .program_size .to_bytes (8 , byteorder = _HEADER_BYTEORDER )
192
+ + self .program_size .to_bytes (8 , byteorder = HEADER_BYTEORDER )
207
193
# uint64_t: Offset to the start of the first segment, or zero if
208
194
# there are no segments.
209
- + self .segment_base_offset .to_bytes (8 , byteorder = _HEADER_BYTEORDER )
195
+ + self .segment_base_offset .to_bytes (8 , byteorder = HEADER_BYTEORDER )
210
196
)
211
197
return data
212
198
213
199
214
- def _pad_to (data : bytes , length : int ) -> bytes :
215
- """Returns the input followed by enough zero bytes to become the requested length.
216
-
217
- Args:
218
- data: The data to pad.
219
- length: The length of the returned data.
220
- Returns:
221
- The padded data.
222
- Raises:
223
- ValueError: If the requested length is less than the input length.
224
- """
225
- if length < len (data ):
226
- raise ValueError (f"Data length { len (data )} > padded length { length } " )
227
- if length > len (data ):
228
- data = data + b"\x00 " * (length - len (data ))
229
- assert len (data ) == length
230
- return data
231
-
232
-
233
200
def _get_extended_header (program_data : bytes ) -> Optional [_ExtendedHeader ]:
234
201
"""Returns the extended header of the program data, if present and valid."""
235
202
try :
@@ -330,7 +297,7 @@ def _extract_constant_segment(
330
297
constant_segment_data .append (buffer .storage )
331
298
buffer_length = len (buffer .storage )
332
299
pad_length = (
333
- _padding_required (buffer_length , tensor_alignment )
300
+ padding_required (buffer_length , tensor_alignment )
334
301
if tensor_alignment is not None
335
302
else 0
336
303
)
@@ -432,11 +399,11 @@ def serialize_pte_binary(
432
399
)
433
400
program .segments .append (
434
401
DataSegment (
435
- offset = _aligned_size (prev_end , segment_alignment ), size = len (data )
402
+ offset = aligned_size (prev_end , segment_alignment ), size = len (data )
436
403
)
437
404
)
438
405
# Add to aggregate segments cord with padding.
439
- padding_length = _padding_required (len (segments_data ), segment_alignment )
406
+ padding_length = padding_required (len (segments_data ), segment_alignment )
440
407
if padding_length > 0 :
441
408
segments_data .append (b"\x00 " * padding_length )
442
409
segments_data .append (data )
@@ -454,15 +421,15 @@ def serialize_pte_binary(
454
421
455
422
# Size of the header to insert. Its size is padded to the largest
456
423
# force_align value present in the schema.
457
- padded_header_length : int = _aligned_size (
424
+ padded_header_length : int = aligned_size (
458
425
input_size = _ExtendedHeader .EXPECTED_LENGTH ,
459
426
alignment = result .max_alignment ,
460
427
)
461
428
# Size of the program with the header inserted.
462
429
program_size : int = padded_header_length + len (result .data )
463
430
# Offset to the first segment, or zero if there are no segments.
464
431
segment_base_offset : int = (
465
- _aligned_size (input_size = program_size , alignment = segment_alignment )
432
+ aligned_size (input_size = program_size , alignment = segment_alignment )
466
433
if len (segments_data ) > 0
467
434
else 0
468
435
)
@@ -471,7 +438,7 @@ def serialize_pte_binary(
471
438
header_data : bytes = _ExtendedHeader (
472
439
program_size = program_size , segment_base_offset = segment_base_offset
473
440
).to_bytes ()
474
- header_data = _pad_to (header_data , padded_header_length )
441
+ header_data = pad_to (header_data , padded_header_length )
475
442
476
443
# Insert the header into the flatbuffer data.
477
444
program_data : bytes = _insert_flatbuffer_header (
@@ -496,7 +463,7 @@ def serialize_pte_binary(
496
463
# - segments data (optional); aligned to segment_alignment.
497
464
pte_data = Cord (program_data )
498
465
if len (segments_data ) > 0 :
499
- padding_length = _padding_required (len (pte_data ), segment_alignment )
466
+ padding_length = padding_required (len (pte_data ), segment_alignment )
500
467
pte_data .append (b"\x00 " * padding_length )
501
468
# The first segment after program data should start at the segment base offset.
502
469
assert (
0 commit comments