Skip to content

Commit 3d41da7

Browse files
committed
[executorch][flat_tensor] Serialize flat tensor
Serialize a flat tensor file. The resulting file layout is: (1) a header containing the flatbuffer offset and size and the segment-data offset and size, (2) the flatbuffer, and (3) the tensor data (in the segment). Differential Revision: [D66374253](https://our.internmc.facebook.com/intern/diff/D66374253/) ghstack-source-id: 257465228 Pull Request resolved: #7268
1 parent bfdc3a3 commit 3d41da7

File tree

7 files changed

+367
-0
lines changed

7 files changed

+367
-0
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ runtime.python_library(
3434
"_flatbuffer.py",
3535
"_program.py",
3636
"utils.py",
37+
"data_serializer.py",
3738
],
3839
resources = {
3940
"//executorch/schema:program.fbs": "program.fbs",

extension/flat_tensor/__init__.py

Whitespace-only changes.

extension/flat_tensor/serialize/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,21 @@ runtime.python_library(
1414
"//executorch/...",
1515
],
1616
)
17+
18+
runtime.python_library(
19+
name = "serialize",
20+
srcs = [
21+
"serialize.py",
22+
],
23+
resources = [
24+
"flat_tensor.fbs",
25+
"scalar_type.fbs",
26+
],
27+
visibility = [
28+
"//executorch/...",
29+
],
30+
deps = [
31+
":schema",
32+
"//executorch/exir/_serialize:lib",
33+
],
34+
)

extension/flat_tensor/serialize/__init__.py

Whitespace-only changes.
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import json
2+
import os
3+
import tempfile
4+
from dataclasses import dataclass
5+
from typing import ClassVar, Dict, List, Optional
6+
7+
import pkg_resources
8+
from executorch.exir._serialize._cord import Cord
9+
from executorch.exir._serialize._dataclass import _DataclassEncoder
10+
11+
from executorch.exir._serialize._flatbuffer import _flatc_compile
12+
from executorch.exir._serialize.data_serializer import DataSerializer, SerializationInfo
13+
14+
from executorch.exir._serialize.utils import (
15+
_aligned_size,
16+
_HEADER_BYTEORDER,
17+
_pad_to,
18+
_padding_required,
19+
)
20+
21+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
22+
DataSegment,
23+
FlatTensor,
24+
TensorMetadata,
25+
)
26+
27+
28+
def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
    """Serialize a FlatTensor dataclass into its flatbuffer binary form.

    The dataclass is dumped to JSON, then compiled with flatc against the
    flat_tensor.fbs schema (which depends on scalar_type.fbs); both schemas
    are shipped as package resources and materialized into a temp dir so
    flatc can read them from disk.
    """
    serialized_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
    with tempfile.TemporaryDirectory() as work_dir:
        # Write the bundled schema files where flatc can find them.
        for resource_name in ("flat_tensor.fbs", "scalar_type.fbs"):
            with open(os.path.join(work_dir, resource_name), "wb") as schema_out:
                schema_out.write(pkg_resources.resource_string(__name__, resource_name))

        json_path = os.path.join(work_dir, "flat_tensor.json")
        with open(json_path, "wb") as json_out:
            json_out.write(serialized_json.encode("ascii"))

        _flatc_compile(work_dir, os.path.join(work_dir, "flat_tensor.fbs"), json_path)
        # flatc names the output after the JSON input, with the schema's
        # declared file extension (.ptd).
        with open(os.path.join(work_dir, "flat_tensor.ptd"), "rb") as compiled:
            return Cord(compiled.read())
49+
50+
51+
@dataclass
class FlatTensorConfig:
    # Alignment, in bytes, that each tensor's data is padded to within the
    # segment.
    tensor_alignment: int = 16
    # Alignment, in bytes, for the segment base offset and the end of the
    # segment data.
    segment_alignment: int = 16
55+
56+
57+
@dataclass
class FlatTensorHeader:
    """Fixed-size binary header placed at the start of a flat_tensor file."""

    # Class constants.
    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"FT01"
    # Total serialized size: 4 (magic) + 4 (uint32 length) + four uint64
    # fields (flatbuffer offset/size, segment base offset, data size).
    EXPECTED_LENGTH: ClassVar[int] = 4 + 4 + 8 + 8 + 8 + 8

    # Instance attributes. @dataclass will turn these into ctor args.

    # Offset to the start of the flatbuffer data, in bytes.
    flatbuffer_offset: int
    # The size of the serialized data in bytes.
    flatbuffer_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int
    # Size of all the segment data, in bytes.
    data_size: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "FlatTensorHeader":
        """Tries to read a flat_tensor header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the flat_tensor header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < FlatTensorHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for flat_tensor header: {len(data)} "
                + f"< {FlatTensorHeader.EXPECTED_LENGTH}"
            )

        def read_u64(start: int) -> int:
            # All wide header fields are unsigned 64-bit ints.
            return int.from_bytes(data[start : start + 8], byteorder=_HEADER_BYTEORDER)

        return FlatTensorHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            flatbuffer_offset=read_u64(8),
            flatbuffer_size=read_u64(16),
            segment_base_offset=read_u64(24),
            data_size=read_u64(32),
        )

    def is_valid(self) -> bool:
        """Returns true if the flat_tensor header appears to be well-formed."""
        if self.magic != FlatTensorHeader.EXPECTED_MAGIC:
            return False
        # A longer-than-expected header is tolerated for forward
        # compatibility; a shorter one is malformed.
        return self.length >= FlatTensorHeader.EXPECTED_LENGTH

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the flat_tensor header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length, since there is no reason to emit an
        invalid header.
        """
        pieces = (
            # Extended header magic. This lets consumers detect whether the
            # header was inserted or not.
            self.EXPECTED_MAGIC,
            # uint32_t: Size of this header. This makes it easier to add new
            # fields to this header in the future.
            self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER),
            # uint64_t: Offset to the start of the flatbuffer data, in bytes.
            self.flatbuffer_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: Size of the serialized data in bytes.
            self.flatbuffer_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: Offset to the start of the first segment, or zero if
            # there are no segments.
            self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: Size of all the segment data, in bytes.
            self.data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
        )
        return b"".join(pieces)
161+
162+
163+
class FlatTensorSerializer(DataSerializer):
    """Serializes tensor data into the flat_tensor (.ptd) file format.

    The produced blob layout is:
        FlatTensorHeader | flatbuffer (FlatTensor) | tensor segment data
    with the header and flatbuffer padded to the configured tensor
    alignment and the segment data starting at a segment-aligned offset.
    """

    def __init__(self, config: Optional[FlatTensorConfig] = None) -> None:
        """Creates a serializer.

        Args:
            config: Alignment configuration; defaults to FlatTensorConfig()
                (16-byte tensor and segment alignment) when None.
        """
        if config is None:
            self.config = FlatTensorConfig()
        else:
            self.config = config

    def serialize_tensors(
        self,
        serialization_info: SerializationInfo,
    ) -> Cord:
        """Serializes the given tensors into a single flat_tensor blob.

        Buffers referenced by multiple fully-qualified names are stored
        once and shared via their offset into the (single) data segment.

        Returns:
            A Cord containing header, flatbuffer, and segment data.
        """
        flat_tensor_metadata: List[TensorMetadata] = []
        flat_tensor_data: Cord = Cord()

        # {idx: offset} — segment offset at which buffer `idx` was stored,
        # so a buffer shared by several fqns is only written once.
        saved_offsets: Dict[int, int] = {}

        for fqn, idx in serialization_info.fqn_to_buffer_index.items():
            tensor_layout = serialization_info.fqn_to_tensor_layout.get(fqn, None)
            assert tensor_layout is not None
            # Check index into the tensor buffers is valid.
            assert idx < len(serialization_info.tensor_buffers)

            # Check if the tensor has already been saved.
            offset = saved_offsets.get(idx, -1)
            if offset == -1:
                if len(flat_tensor_data) > 0:
                    # Add padding to round off the previous tensor offset.
                    pad_length = _padding_required(
                        len(flat_tensor_data), self.config.tensor_alignment
                    )
                    flat_tensor_data.append(b"\x00" * pad_length)
                # Record where this buffer starts, then append it.
                offset = len(flat_tensor_data)
                saved_offsets[idx] = offset
                flat_tensor_data.append(serialization_info.tensor_buffers[idx])

            flat_tensor_metadata.append(
                TensorMetadata(
                    fully_qualified_name=fqn,
                    scalar_type=tensor_layout.scalar_type,
                    dim_sizes=tensor_layout.dim_sizes,
                    dim_order=tensor_layout.dim_order,
                    # All tensors currently live in the single segment 0.
                    segment_index=0,
                    offset=offset,
                )
            )
        # Pad the end of the segment data to segment alignment.
        segment_pad_length = _padding_required(
            len(flat_tensor_data), self.config.segment_alignment
        )
        if segment_pad_length > 0:
            flat_tensor_data.append(b"\x00" * segment_pad_length)

        # Organize the tensors and segments.
        flat_tensor = FlatTensor(
            version=0,
            tensor_alignment=self.config.tensor_alignment,
            tensors=flat_tensor_metadata,
            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
        )

        flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
        padded_flatbuffer_length: int = _aligned_size(
            input_size=len(flatbuffer_payload),
            alignment=self.config.tensor_alignment,
        )

        padded_header_length: int = _aligned_size(
            input_size=FlatTensorHeader.EXPECTED_LENGTH,
            alignment=self.config.tensor_alignment,
        )

        # The segment begins after the padded header and flatbuffer,
        # rounded up to the segment alignment.
        segment_base_offset = _aligned_size(
            padded_flatbuffer_length + padded_header_length,
            self.config.segment_alignment,
        )

        header_data: bytes = FlatTensorHeader(
            flatbuffer_offset=padded_header_length,
            flatbuffer_size=len(flatbuffer_payload),
            segment_base_offset=segment_base_offset,
            data_size=len(flat_tensor_data),
        ).to_bytes()

        # Pad header and flatbuffer to their recorded (tensor-aligned)
        # lengths.
        header_data = _pad_to(header_data, padded_header_length)
        flatbuffer_payload.append(
            b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
        )

        # Fix: also pad out to segment_base_offset so the segment data
        # actually starts where the header says it does. This only matters
        # when segment_alignment exceeds the alignment already provided by
        # the padded header/flatbuffer (a no-op with the default 16/16
        # config).
        pre_segment_length = padded_header_length + padded_flatbuffer_length
        if segment_base_offset > pre_segment_length:
            flatbuffer_payload.append(
                b"\x00" * (segment_base_offset - pre_segment_length)
            )

        # Place everything into one contiguous payload.
        payload = Cord()
        payload.append(header_data)
        payload.append(flatbuffer_payload)
        payload.append(flat_tensor_data)

        return payload

    def deserialize_tensors(self, blob: Cord) -> SerializationInfo:
        """
        Deserializes a blob into a list of tensor metadata and tensors.
        Not yet implemented.
        """
        # Fix: the original message named a nonexistent "deserialize_data"
        # method; report the actual unimplemented method name.
        raise NotImplementedError("deserialize_tensors")

extension/flat_tensor/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
3+
oncall("executorch")
4+
5+
python_unittest(
6+
name = "serialize",
7+
srcs = [
8+
"test_serialize.py",
9+
],
10+
deps = [
11+
"//executorch/extension/flat_tensor/serialize:serialize",
12+
"//executorch/extension/flat_tensor/serialize:schema",
13+
],
14+
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import typing
8+
import unittest
9+
from typing import List
10+
11+
from executorch.exir._serialize.data_serializer import (
12+
DataSerializer,
13+
SerializationInfo,
14+
TensorLayout,
15+
)
16+
17+
from executorch.exir.schema import ScalarType
18+
19+
from executorch.extension.flat_tensor.serialize.serialize import (
20+
FlatTensorHeader,
21+
FlatTensorSerializer,
22+
)
23+
24+
# Test artifacts
# A single shared buffer; both fully-qualified names below reference
# buffer index 0, exercising the deduplication path in the serializer.
TEST_TENSOR_BUFFER = [b"tensor"]
TEST_TENSOR_MAP = {
    "fqn1": 0,
    "fqn2": 0,
}

# Layouts for each fqn: float tensors of shape [1, 1, 1] with contiguous
# dim order.
TEST_TENSOR_LAYOUT = {
    "fqn1": TensorLayout(
        scalar_type=ScalarType.FLOAT,
        dim_sizes=[1, 1, 1],
        # TensorLayout declares dim_order as List[bytes]; cast the int
        # literals to satisfy the type checker.
        dim_order=typing.cast(List[bytes], [0, 1, 2]),
    ),
    "fqn2": TensorLayout(
        scalar_type=ScalarType.FLOAT,
        dim_sizes=[1, 1, 1],
        dim_order=typing.cast(List[bytes], [0, 1, 2]),
    ),
}
43+
44+
45+
class TestSerialize(unittest.TestCase):
    def test_serialize(self) -> None:
        """End-to-end check of the FlatTensorSerializer output layout."""
        serializer: DataSerializer = FlatTensorSerializer()

        data = bytes(
            serializer.serialize_tensors(
                SerializationInfo(
                    TEST_TENSOR_BUFFER, TEST_TENSOR_MAP, TEST_TENSOR_LAYOUT
                )
            )
        )

        # The blob must begin with a well-formed FlatTensorHeader.
        header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
        self.assertTrue(header.is_valid())

        # Expected values derive from the default 16-byte alignments and
        # the size of the flatbuffer flatc produces for this fixture
        # (flatbuffer_size=200 is specific to the current schema; it will
        # change if the schema changes).
        self.assertEqual(header.flatbuffer_offset, 48)
        self.assertEqual(header.flatbuffer_size, 200)
        self.assertEqual(header.segment_base_offset, 256)
        self.assertEqual(header.data_size, 16)

        # Bytes 4-8 of a flatbuffer hold its file_identifier; the
        # flat_tensor schema declares "FT01".
        self.assertEqual(
            data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
        )

0 commit comments

Comments
 (0)