Skip to content

Commit c6faa2d

Browse files
committed
[executorch][serialization] Refactor flatbuffer utils into separate file
Pull Request resolved: #7254 Todo: let xnnpack and vulkan serialization use these utils instead of redefining the same functions. For usage in extension/flat_tensor/serialize. ghstack-source-id: 258747562 @exported-using-ghexport Differential Revision: [D66854756](https://our.internmc.facebook.com/intern/diff/D66854756/)
1 parent 61c9c95 commit c6faa2d

File tree

3 files changed

+66
-56
lines changed

3 files changed

+66
-56
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.python_library(
3333
"_dataclass.py",
3434
"_flatbuffer.py",
3535
"_program.py",
36+
"utils.py",
3637
],
3738
resources = {
3839
"//executorch/schema:program.fbs": "program.fbs",

exir/_serialize/_program.py

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212

1313
from dataclasses import dataclass
14-
from typing import ClassVar, List, Literal, Optional, Tuple
14+
from typing import ClassVar, List, Optional, Tuple
1515

1616
from executorch.exir._serialize._cord import Cord
1717
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,6 +21,13 @@
2121
_program_json_to_flatbuffer,
2222
)
2323

24+
from executorch.exir._serialize.utils import (
25+
aligned_size,
26+
HEADER_BYTEORDER,
27+
pad_to,
28+
padding_required,
29+
)
30+
2431
from executorch.exir.schema import (
2532
BackendDelegateDataReference,
2633
BackendDelegateInlineData,
@@ -33,12 +40,6 @@
3340
from executorch.exir.tensor import ALIGNMENT
3441

3542

36-
# Byte order of numbers written to program headers. Always little-endian
37-
# regardless of the host system, since all commonly-used modern CPUs are little
38-
# endian.
39-
_HEADER_BYTEORDER: Literal["little"] = "little"
40-
41-
4243
def _program_to_json(program: Program) -> str:
4344
"""Returns the JSON representation of the given Program."""
4445
return json.dumps(program, cls=_DataclassEncoder)
@@ -50,19 +51,6 @@ def _json_to_program(program_json: bytes) -> Program:
5051
return _json_to_dataclass(json.loads(program_json), cls=Program)
5152

5253

53-
def _padding_required(offset: int, alignment: int) -> int:
54-
"""Returns the padding required to align `offset` to `alignment`."""
55-
remainder: int = offset % alignment
56-
if remainder != 0:
57-
return alignment - remainder
58-
return 0
59-
60-
61-
def _aligned_size(input_size: int, alignment: int) -> int:
62-
"""Returns input_size padded up to the next whole multiple of alignment."""
63-
return input_size + _padding_required(input_size, alignment)
64-
65-
6654
def _insert_flatbuffer_header(
6755
flatbuffer_data: bytes, magic_regex: str, header_data: bytes
6856
) -> bytes:
@@ -102,11 +90,11 @@ def _insert_flatbuffer_header(
10290
return flatbuffer_data
10391

10492
# We will need to adjust the root object offset after inserting the header.
105-
root_offset = int.from_bytes(flatbuffer_data[0:4], byteorder=_HEADER_BYTEORDER)
93+
root_offset = int.from_bytes(flatbuffer_data[0:4], byteorder=HEADER_BYTEORDER)
10694

10795
return (
10896
# New root offset.
109-
(root_offset + len(header_data)).to_bytes(4, byteorder=_HEADER_BYTEORDER)
97+
(root_offset + len(header_data)).to_bytes(4, byteorder=HEADER_BYTEORDER)
11098
# Existing magic bytes.
11199
+ flatbuffer_data[4:8]
112100
# Provided header + padding.
@@ -171,11 +159,9 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
171159

172160
return _ExtendedHeader(
173161
magic=data[0:4],
174-
length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
175-
program_size=int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER),
176-
segment_base_offset=int.from_bytes(
177-
data[16:24], byteorder=_HEADER_BYTEORDER
178-
),
162+
length=int.from_bytes(data[4:8], byteorder=HEADER_BYTEORDER),
163+
program_size=int.from_bytes(data[8:16], byteorder=HEADER_BYTEORDER),
164+
segment_base_offset=int.from_bytes(data[16:24], byteorder=HEADER_BYTEORDER),
179165
)
180166

181167
def is_valid(self) -> bool:
@@ -201,35 +187,16 @@ def to_bytes(self) -> bytes:
201187
# fields to this header in the future. Always use the proper size
202188
# (i.e., ignore self.length) since there's no reason to create an
203189
# invalid header.
204-
+ self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER)
190+
+ self.EXPECTED_LENGTH.to_bytes(4, byteorder=HEADER_BYTEORDER)
205191
# uint64_t: Size of the flatbuffer data, including this header.
206-
+ self.program_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
192+
+ self.program_size.to_bytes(8, byteorder=HEADER_BYTEORDER)
207193
# uint64_t: Offset to the start of the first segment, or zero if
208194
# there are no segments.
209-
+ self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
195+
+ self.segment_base_offset.to_bytes(8, byteorder=HEADER_BYTEORDER)
210196
)
211197
return data
212198

213199

214-
def _pad_to(data: bytes, length: int) -> bytes:
215-
"""Returns the input followed by enough zero bytes to become the requested length.
216-
217-
Args:
218-
data: The data to pad.
219-
length: The length of the returned data.
220-
Returns:
221-
The padded data.
222-
Raises:
223-
ValueError: If the requested length is less than the input length.
224-
"""
225-
if length < len(data):
226-
raise ValueError(f"Data length {len(data)} > padded length {length}")
227-
if length > len(data):
228-
data = data + b"\x00" * (length - len(data))
229-
assert len(data) == length
230-
return data
231-
232-
233200
def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
234201
"""Returns the extended header of the program data, if present and valid."""
235202
try:
@@ -330,7 +297,7 @@ def _extract_constant_segment(
330297
constant_segment_data.append(buffer.storage)
331298
buffer_length = len(buffer.storage)
332299
pad_length = (
333-
_padding_required(buffer_length, tensor_alignment)
300+
padding_required(buffer_length, tensor_alignment)
334301
if tensor_alignment is not None
335302
else 0
336303
)
@@ -432,11 +399,11 @@ def serialize_pte_binary(
432399
)
433400
program.segments.append(
434401
DataSegment(
435-
offset=_aligned_size(prev_end, segment_alignment), size=len(data)
402+
offset=aligned_size(prev_end, segment_alignment), size=len(data)
436403
)
437404
)
438405
# Add to aggregate segments cord with padding.
439-
padding_length = _padding_required(len(segments_data), segment_alignment)
406+
padding_length = padding_required(len(segments_data), segment_alignment)
440407
if padding_length > 0:
441408
segments_data.append(b"\x00" * padding_length)
442409
segments_data.append(data)
@@ -454,15 +421,15 @@ def serialize_pte_binary(
454421

455422
# Size of the header to insert. Its size is padded to the largest
456423
# force_align value present in the schema.
457-
padded_header_length: int = _aligned_size(
424+
padded_header_length: int = aligned_size(
458425
input_size=_ExtendedHeader.EXPECTED_LENGTH,
459426
alignment=result.max_alignment,
460427
)
461428
# Size of the program with the header inserted.
462429
program_size: int = padded_header_length + len(result.data)
463430
# Offset to the first segment, or zero if there are no segments.
464431
segment_base_offset: int = (
465-
_aligned_size(input_size=program_size, alignment=segment_alignment)
432+
aligned_size(input_size=program_size, alignment=segment_alignment)
466433
if len(segments_data) > 0
467434
else 0
468435
)
@@ -471,7 +438,7 @@ def serialize_pte_binary(
471438
header_data: bytes = _ExtendedHeader(
472439
program_size=program_size, segment_base_offset=segment_base_offset
473440
).to_bytes()
474-
header_data = _pad_to(header_data, padded_header_length)
441+
header_data = pad_to(header_data, padded_header_length)
475442

476443
# Insert the header into the flatbuffer data.
477444
program_data: bytes = _insert_flatbuffer_header(
@@ -496,7 +463,7 @@ def serialize_pte_binary(
496463
# - segments data (optional); aligned to segment_alignment.
497464
pte_data = Cord(program_data)
498465
if len(segments_data) > 0:
499-
padding_length = _padding_required(len(pte_data), segment_alignment)
466+
padding_length = padding_required(len(pte_data), segment_alignment)
500467
pte_data.append(b"\x00" * padding_length)
501468
# The first segment after program data should start at the segment base offset.
502469
assert (

exir/_serialize/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Literal
6+
7+
# Byte order of numbers written to program headers. Always little-endian
8+
# regardless of the host system, since all commonly-used modern CPUs are little
9+
# endian.
10+
HEADER_BYTEORDER: Literal["little"] = "little"
11+
12+
13+
def pad_to(data: bytes, length: int) -> bytes:
14+
"""Returns the input followed by enough zero bytes to become the requested length.
15+
16+
Args:
17+
data: The data to pad.
18+
length: The length of the returned data.
19+
Returns:
20+
The padded data.
21+
Raises:
22+
ValueError: If the requested length is less than the input length.
23+
"""
24+
if length < len(data):
25+
raise ValueError(f"Data length {len(data)} > padded length {length}")
26+
if length > len(data):
27+
data = data + b"\x00" * (length - len(data))
28+
assert len(data) == length
29+
return data
30+
31+
32+
def padding_required(offset: int, alignment: int) -> int:
33+
"""Returns the padding required to align `offset` to `alignment`."""
34+
remainder: int = offset % alignment
35+
if remainder != 0:
36+
return alignment - remainder
37+
return 0
38+
39+
40+
def aligned_size(input_size: int, alignment: int) -> int:
41+
"""Returns input_size padded up to the next whole multiple of alignment."""
42+
return input_size + padding_required(input_size, alignment)

0 commit comments

Comments
 (0)