Skip to content

Commit f2987eb

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
embed extended header inside .ptd flatbuffer section (#7965)
Summary: Embed the header inside the flatbuffer. We do this for .pte and it lets us reuse a lot of flatbuffer tools natively. Differential Revision: D68578075
1 parent a836b64 commit f2987eb

File tree

4 files changed

+89
-29
lines changed

4 files changed

+89
-29
lines changed

extension/flat_tensor/serialize/serialize.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,38 @@ runtime::Error save_ptd(
109109
tensor_alignment,
110110
builder.CreateVector(tensors),
111111
builder.CreateVector(buffers));
112-
builder.Finish(flat_tensor); // Our flatbuffer is created now.
112+
builder.Finish(flat_tensor, ::flat_tensor_flatbuffer::FlatTensorIdentifier());
113+
// Our flatbuffer is created now.
113114

114115
// Calculate flatbuffer padding.
115116
auto padded_flatbufer_size =
116117
aligned_size(builder.GetSize(), tensor_alignment);
117118
auto padded_header_size =
118119
aligned_size(FlatTensorHeader::kHeaderExpectedLength, tensor_alignment);
119120

121+
// The general structure of the file is:
122+
// [flatbuffer offset to root table][flat tensor magic bytes][header]
123+
// [flatbuffer contents][padding][segment data][padding].
124+
// This means we first serialize the first 8 bytes of the flatbuffer,
125+
// updating the offset to the root table, then the header, then the
126+
// flatbuffer. We are embedding the header inside the flatbuffer doing
127+
// this which allows us to continue using flatbuffer tools directly on the
128+
// .ptd file.
129+
130+
// Calculate new offset to root table.
131+
uint32_t current_offset =
132+
*reinterpret_cast<uint32_t*>(builder.GetBufferPointer());
133+
uint32_t new_offset = current_offset + padded_header_size;
134+
135+
// Write flatbuffer offset to root table
136+
out.write(reinterpret_cast<const char*>(&new_offset), sizeof(new_offset));
137+
138+
// Write flatbuffer magic bytes
139+
out.write(
140+
reinterpret_cast<const char*>(builder.GetBufferPointer()) +
141+
sizeof(new_offset),
142+
4); // This is the 'FT01' magic bytes from flat_tensor.fbs.
143+
120144
// Write header
121145
out.write(FlatTensorHeader::kMagic, sizeof(FlatTensorHeader::kMagic));
122146
out.write(
@@ -149,10 +173,10 @@ runtime::Error save_ptd(
149173
padding_required(
150174
FlatTensorHeader::kHeaderExpectedLength, tensor_alignment));
151175

152-
// Write flatbuffer
176+
// Write flatbuffer, offset by 8 bytes since we wrote those before the header.
153177
out.write(
154-
reinterpret_cast<const char*>(builder.GetBufferPointer()),
155-
builder.GetSize());
178+
reinterpret_cast<const char*>(builder.GetBufferPointer()) + 8,
179+
builder.GetSize() - 8);
156180

157181
// Write flatbuffer padding
158182
write_nulls(out, padding_required(builder.GetSize(), tensor_alignment));

extension/flat_tensor/serialize/serialize.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1818

1919
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
20+
from executorch.exir._serialize._program import _insert_flatbuffer_header
2021
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
2122

2223
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -197,6 +198,17 @@ def to_bytes(self) -> bytes:
197198
return data
198199

199200

201+
def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
202+
"""Returns the extended header of the flat_tensor data, if present and valid."""
203+
try:
204+
eh = FlatTensorHeader.from_bytes(flat_tensor_data[8:])
205+
if eh.is_valid():
206+
return eh
207+
except ValueError:
208+
pass
209+
return None
210+
211+
200212
class FlatTensorSerializer(DataSerializer):
201213
"""A concrete implementation of the DataSerializer interface that
202214
serializes and deserializes data to/from the FlatTensor format.
@@ -299,14 +311,29 @@ def serialize(
299311

300312
# Pad header and payload to segment alignment.
301313
header_data = pad_to(header_data, padded_header_length)
314+
original_flatbuffer_payload_size = len(flatbuffer_payload)
302315
flatbuffer_payload.append(
303316
b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
304317
)
318+
injected_flatbuffer_data: bytes = _insert_flatbuffer_header(
319+
flatbuffer_data=flatbuffer_payload.__bytes__(),
320+
magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]",
321+
header_data=header_data,
322+
)
323+
324+
eh = _get_extended_header(injected_flatbuffer_data)
325+
assert eh is not None
326+
assert eh.flatbuffer_size == original_flatbuffer_payload_size
327+
assert eh.segment_base_offset == segment_base_offset
328+
assert eh.flatbuffer_offset == padded_header_length
329+
assert eh.segment_data_size == len(flat_tensor_data)
330+
331+
del header_data
332+
del flatbuffer_payload
305333

306334
# Place everything into one segment.
307335
payload = Cord()
308-
payload.append(header_data)
309-
payload.append(flatbuffer_payload)
336+
payload.append(injected_flatbuffer_data)
310337
payload.append(flat_tensor_data)
311338

312339
return payload

extension/flat_tensor/test/test_serialize.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,35 +53,46 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) {
5353
auto x = buf.str();
5454
const char* byte_buffer = x.c_str();
5555

56-
// Check Magic
57-
EXPECT_EQ(byte_buffer[0], 'F');
58-
EXPECT_EQ(byte_buffer[1], 'H');
59-
EXPECT_EQ(byte_buffer[2], '0');
60-
EXPECT_EQ(byte_buffer[3], '1');
56+
// First 4 bytes are an offset to the flatbuffer root table.
57+
58+
// Check magic ids.
59+
EXPECT_EQ(byte_buffer[4], 'F');
60+
EXPECT_EQ(byte_buffer[5], 'T');
61+
ASSERT_EQ(byte_buffer[6], '0');
62+
ASSERT_EQ(byte_buffer[7], '1');
63+
64+
ASSERT_EQ(byte_buffer[8], 'F');
65+
ASSERT_EQ(byte_buffer[9], 'H');
66+
EXPECT_EQ(byte_buffer[10], '0');
67+
EXPECT_EQ(byte_buffer[11], '1');
6168

6269
// Check Header
63-
EXPECT_EQ( // Header length
64-
*(uint32_t*)(byte_buffer + 4),
70+
auto header_buffer = byte_buffer + 8;
71+
EXPECT_EQ( // Check expected length
72+
*(uint32_t*)(header_buffer + 4),
6573
executorch::extension::FlatTensorHeader::kHeaderExpectedLength);
74+
6675
EXPECT_EQ(
67-
*(uint64_t*)(byte_buffer + 8),
68-
48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding today,
69-
// and then the flatbuffer starts.
76+
*(uint64_t*)(header_buffer + 8),
77+
48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding
78+
// today, and then the flatbuffer starts.
79+
7080
EXPECT_EQ(
71-
*(uint64_t*)(byte_buffer + 16),
72-
224); // Flatbuffer size, This is fragile, and depends on the schema, the
73-
// builder, and the padding needed.
74-
const uint64_t segment_offset = 48 +
75-
224; // Segment offset, depends on the padded header and flatbuffer sizes.
76-
EXPECT_EQ(*(uint64_t*)(byte_buffer + 24), segment_offset);
81+
*(uint64_t*)(header_buffer + 16),
82+
232); // Flatbuffer size. This is fragile, and depends on the schema,
83+
// the builder, and the padding needed.
84+
85+
// Segment offset, depends on the padded header and flatbuffer sizes.
86+
const uint64_t segment_offset = 48 + 232 + 8; // 8 is padding.
87+
EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset);
7788

7889
EXPECT_EQ(
79-
*(uint64_t*)(byte_buffer + 32),
90+
*(uint64_t*)(header_buffer + 32),
8091
20); // Segment total size, 8 bytes of data (2 floats), 24 bytes of
8192
// padding.
8293

8394
// Check Flatbuffer
84-
auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer + 48);
95+
auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer);
8596

8697
EXPECT_EQ(
8798
flat_tensor->version(),

extension/flat_tensor/test/test_serialize.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_serialize(self) -> None:
8080

8181
# Check header.
8282
header = FlatTensorHeader.from_bytes(
83-
serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH]
83+
serialized_data[8 : FlatTensorHeader.EXPECTED_LENGTH + 8]
8484
)
8585
self.assertTrue(header.is_valid())
8686

@@ -107,15 +107,13 @@ def test_serialize(self) -> None:
107107

108108
# Confirm the flatbuffer magic is present.
109109
self.assertEqual(
110-
serialized_data[
111-
header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8
112-
],
110+
serialized_data[4:8],
113111
b"FT01",
114112
)
115113

116114
# Check flat tensor data.
117115
flat_tensor_bytes = serialized_data[
118-
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
116+
0 : header.flatbuffer_offset + header.flatbuffer_size
119117
]
120118

121119
flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)

0 commit comments

Comments
 (0)