Skip to content

embed extended header inside .ptd flatbuffer section #7965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions extension/flat_tensor/serialize/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,39 @@ runtime::Error save_ptd(
tensor_alignment,
builder.CreateVector(tensors),
builder.CreateVector(buffers));
builder.Finish(flat_tensor); // Our flatbuffer is created now.
builder.Finish(flat_tensor, ::flat_tensor_flatbuffer::FlatTensorIdentifier());
// Our flatbuffer is created now.

// Calculate flatbuffer padding.
auto padded_flatbufer_size =
aligned_size(builder.GetSize(), tensor_alignment);
auto padded_header_size =
aligned_size(FlatTensorHeader::kHeaderExpectedLength, tensor_alignment);

// The general structure of the file is:
// [flatbuffer offset to root table][flatbuffer file identifier]
// [FlatTensorHeader][padding][flatbuffer contents][padding]
// [segment data].
// This means we first serialize the first 8 bytes of the flatbuffer,
// updating the offset to the root table, then the header, then the
// flatbuffer. By doing this we embed the header inside the flatbuffer,
// which allows us to continue using flatbuffer tools directly on the
// .ptd file.

// Calculate new offset to root table.
uint32_t current_offset =
*reinterpret_cast<uint32_t*>(builder.GetBufferPointer());
uint32_t new_offset = current_offset + padded_header_size;

// Write flatbuffer offset to root table
out.write(reinterpret_cast<const char*>(&new_offset), sizeof(new_offset));

// Write flatbuffer magic bytes
out.write(
reinterpret_cast<const char*>(builder.GetBufferPointer()) +
sizeof(new_offset),
4); // This is the file identifier from flat_tensor.fbs.

// Write header
out.write(FlatTensorHeader::kMagic, sizeof(FlatTensorHeader::kMagic));
out.write(
Expand Down Expand Up @@ -149,10 +174,11 @@ runtime::Error save_ptd(
padding_required(
FlatTensorHeader::kHeaderExpectedLength, tensor_alignment));

// Write flatbuffer
// Write flatbuffer, offset by 8 bytes (4-byte root table offset + 4-byte
// file identifier) since we wrote those before the FlatTensorHeader.
out.write(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
reinterpret_cast<const char*>(builder.GetBufferPointer()) + 8,
builder.GetSize() - 8);

// Write flatbuffer padding
write_nulls(out, padding_required(builder.GetSize(), tensor_alignment));
Expand Down
31 changes: 29 additions & 2 deletions extension/flat_tensor/serialize/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass

from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
from executorch.exir._serialize._program import _insert_flatbuffer_header
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
Expand Down Expand Up @@ -197,6 +198,17 @@ def to_bytes(self) -> bytes:
return data


def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
"""Returns the extended header of the flat_tensor data, if present and valid."""
try:
eh = FlatTensorHeader.from_bytes(flat_tensor_data[8:])
if eh.is_valid():
return eh
except ValueError:
pass
return None


class FlatTensorSerializer(DataSerializer):
"""A concrete implementation of the DataSerializer interface that
serializes and deserializes data to/from the FlatTensor format.
Expand Down Expand Up @@ -299,14 +311,29 @@ def serialize(

# Pad header and payload to segment alignment.
header_data = pad_to(header_data, padded_header_length)
original_flatbuffer_payload_size = len(flatbuffer_payload)
flatbuffer_payload.append(
b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
)
injected_flatbuffer_data: bytes = _insert_flatbuffer_header(
flatbuffer_data=flatbuffer_payload.__bytes__(),
magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]",
header_data=header_data,
)

eh = _get_extended_header(injected_flatbuffer_data)
assert eh is not None
assert eh.flatbuffer_size == original_flatbuffer_payload_size
assert eh.segment_base_offset == segment_base_offset
assert eh.flatbuffer_offset == padded_header_length
assert eh.segment_data_size == len(flat_tensor_data)

del header_data
del flatbuffer_payload

# Place everything into one segment.
payload = Cord()
payload.append(header_data)
payload.append(flatbuffer_payload)
payload.append(injected_flatbuffer_data)
payload.append(flat_tensor_data)

return payload
Expand Down
47 changes: 29 additions & 18 deletions extension/flat_tensor/test/test_serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,46 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) {
auto x = buf.str();
const char* byte_buffer = x.c_str();

// Check Magic
EXPECT_EQ(byte_buffer[0], 'F');
EXPECT_EQ(byte_buffer[1], 'H');
EXPECT_EQ(byte_buffer[2], '0');
EXPECT_EQ(byte_buffer[3], '1');
// First 4 bytes are an offset to the flatbuffer root table.

// Check magic ids.
EXPECT_EQ(byte_buffer[4], 'F');
EXPECT_EQ(byte_buffer[5], 'T');
ASSERT_EQ(byte_buffer[6], '0');
ASSERT_EQ(byte_buffer[7], '1');

ASSERT_EQ(byte_buffer[8], 'F');
ASSERT_EQ(byte_buffer[9], 'H');
EXPECT_EQ(byte_buffer[10], '0');
EXPECT_EQ(byte_buffer[11], '1');

// Check Header
EXPECT_EQ( // Header length
*(uint32_t*)(byte_buffer + 4),
auto header_buffer = byte_buffer + 8;
EXPECT_EQ( // Check expected length
*(uint32_t*)(header_buffer + 4),
executorch::extension::FlatTensorHeader::kHeaderExpectedLength);

EXPECT_EQ(
*(uint64_t*)(byte_buffer + 8),
48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding today,
// and then the flatbuffer starts.
*(uint64_t*)(header_buffer + 8),
48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding
// today, and then the flatbuffer starts.

EXPECT_EQ(
*(uint64_t*)(byte_buffer + 16),
224); // Flatbuffer size, This is fragile, and depends on the schema, the
// builder, and the padding needed.
const uint64_t segment_offset = 48 +
224; // Segment offset, depends on the padded header and flatbuffer sizes.
EXPECT_EQ(*(uint64_t*)(byte_buffer + 24), segment_offset);
*(uint64_t*)(header_buffer + 16),
232); // Flatbuffer size. This is fragile, and depends on the schema,
// the builder, and the padding needed.

// Segment offset, depends on the padded header and flatbuffer sizes.
const uint64_t segment_offset = 48 + 232 + 8; // 8 is padding.
EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset);

EXPECT_EQ(
*(uint64_t*)(byte_buffer + 32),
*(uint64_t*)(header_buffer + 32),
20); // Segment total size, 8 bytes of data (2 floats), 24 bytes of
// padding.

// Check Flatbuffer
auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer + 48);
auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer);

EXPECT_EQ(
flat_tensor->version(),
Expand Down
8 changes: 3 additions & 5 deletions extension/flat_tensor/test/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_serialize(self) -> None:

# Check header.
header = FlatTensorHeader.from_bytes(
serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH]
serialized_data[8 : FlatTensorHeader.EXPECTED_LENGTH + 8]
)
self.assertTrue(header.is_valid())

Expand All @@ -107,15 +107,13 @@ def test_serialize(self) -> None:

# Confirm the flatbuffer magic is present.
self.assertEqual(
serialized_data[
header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8
],
serialized_data[4:8],
b"FT01",
)

# Check flat tensor data.
flat_tensor_bytes = serialized_data[
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
0 : header.flatbuffer_offset + header.flatbuffer_size
]

flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
Expand Down
Loading