Skip to content

Commit 76cc1fa

Browse files
pytorchbot and lucylq
authored and committed
[executorch][flat_tensor] Serialize flat tensor (#7641)
Pull Request resolved: #7268 Serialize a flat tensor file. The resulting file looks like: Header containing: - flatbuffer offset and size - segment data offset and size Flatbuffer containing: - Items described in [flat_tensor.fbs](https://www.internalfb.com/code/fbsource/[079ba95593be856a16783bd3f3b3579580595fbb]/fbcode/executorch/extension/flat_tensor/flat_tensor.fbs) Tensor data (in segment) - Raw tensor data ghstack-source-id: 261273078 @exported-using-ghexport Differential Revision: [D66374253](https://our.internmc.facebook.com/intern/diff/D66374253/) Co-authored-by: lucylq <[email protected]>
1 parent f1aca0f commit 76cc1fa

File tree

8 files changed

+409
-1
lines changed

8 files changed

+409
-1
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.python_library(
3333
"_dataclass.py",
3434
"_flatbuffer.py",
3535
"_program.py",
36+
"data_serializer.py",
3637
"padding.py",
3738
],
3839
resources = {

extension/flat_tensor/__init__.py

Whitespace-only changes.

extension/flat_tensor/serialize/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,21 @@ runtime.python_library(
1414
"//executorch/...",
1515
],
1616
)
17+
18+
# Serialization entry point for the FlatTensor (.ptd) format; bundles the
# flatbuffer schemas as resources so serialize.py can compile against them.
runtime.python_library(
    name = "serialize",
    srcs = [
        "serialize.py",
    ],
    resources = [
        "flat_tensor.fbs",
        "scalar_type.fbs",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        ":schema",
        "//executorch/exir/_serialize:lib",
    ],
)

extension/flat_tensor/serialize/__init__.py

Whitespace-only changes.

extension/flat_tensor/serialize/flat_tensor_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class TensorMetadata:
1818
fully_qualified_name: str
1919
scalar_type: ScalarType
2020
sizes: List[int]
21-
dim_order: List[bytes]
21+
dim_order: List[int]
2222

2323
segment_index: int
2424
offset: int
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import json
10+
import os
11+
import tempfile
12+
from dataclasses import dataclass
13+
from typing import ClassVar, Dict, List, Literal, Optional
14+
15+
import pkg_resources
16+
from executorch.exir._serialize._cord import Cord
17+
from executorch.exir._serialize._dataclass import _DataclassEncoder
18+
19+
from executorch.exir._serialize._flatbuffer import _flatc_compile
20+
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
21+
22+
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
23+
24+
# Byte order of numbers written to flat tensor headers. Always little-endian
25+
# regardless of the host system, since all commonly-used modern CPUs are little
26+
# endian.
27+
_HEADER_BYTEORDER: Literal["little"] = "little"
28+
29+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
30+
DataSegment,
31+
FlatTensor,
32+
TensorMetadata,
33+
)
34+
35+
36+
def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
    """Serialize a ``FlatTensor`` dataclass tree into flatbuffer binary data.

    Dumps the dataclasses to JSON, materializes the packaged .fbs schemas
    into a scratch directory, then invokes flatc to compile the JSON into
    the binary flatbuffer form.

    Args:
        flat_tensor: The FlatTensor to serialize.
    Returns:
        A Cord holding the serialized flatbuffer bytes.
    """
    serialized_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
    with tempfile.TemporaryDirectory() as work_dir:
        # flatc needs the schema, and the schema it includes, on disk.
        for schema_name in ("flat_tensor.fbs", "scalar_type.fbs"):
            with open(os.path.join(work_dir, schema_name), "wb") as schema_out:
                schema_out.write(
                    pkg_resources.resource_string(__name__, schema_name)
                )

        json_path = os.path.join(work_dir, "flat_tensor.json")
        with open(json_path, "wb") as json_out:
            json_out.write(serialized_json.encode("ascii"))

        _flatc_compile(work_dir, os.path.join(work_dir, "flat_tensor.fbs"), json_path)
        # flatc names its output after the input json, with a .ptd extension
        # (per the schema's file_identifier/extension configuration).
        with open(os.path.join(work_dir, "flat_tensor.ptd"), "rb") as compiled:
            return Cord(compiled.read())
58+
59+
60+
@dataclass
class FlatTensorConfig:
    """Layout knobs for serializing a FlatTensor (.ptd) file."""

    # Each tensor's data is padded so it starts on a multiple of this value.
    tensor_alignment: int = 16
    # The segment data (and the header/flatbuffer region before it) is padded
    # to a multiple of this value.
    segment_alignment: int = 16
64+
65+
66+
@dataclass
class FlatTensorHeader:
    """Fixed-size binary header preceding the flatbuffer in a .ptd file.

    Layout (all integers little-endian):
        magic (4 bytes) | length (u32) | flatbuffer_offset (u64) |
        flatbuffer_size (u64) | segment_base_offset (u64) |
        segment_data_size (u64)
    """

    # Class constants.
    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"FH01"
    # Total serialized size: magic (4) + length (4) + four u64 fields (4 * 8).
    EXPECTED_LENGTH: ClassVar[int] = 4 + 4 + 8 + 8 + 8 + 8

    # Instance attributes. @dataclass will turn these into ctor args.

    # Offset to the start of the flatbuffer data, in bytes.
    flatbuffer_offset: int
    # The size of the serialized data in bytes.
    flatbuffer_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int
    # Size of all the segment data, in bytes.
    segment_data_size: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "FlatTensorHeader":
        """Tries to read a flat_tensor header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the flat_tensor header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < FlatTensorHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for flat_tensor header: {len(data)} "
                + f"< {FlatTensorHeader.EXPECTED_LENGTH}"
            )

        def u64_at(start: int) -> int:
            # Decode an unsigned 64-bit little-endian field at `start`.
            return int.from_bytes(
                data[start : start + 8], byteorder=_HEADER_BYTEORDER
            )

        return FlatTensorHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            flatbuffer_offset=u64_at(8),
            flatbuffer_size=u64_at(16),
            segment_base_offset=u64_at(24),
            segment_data_size=u64_at(32),
        )

    def is_valid(self) -> bool:
        """Returns true if the flat_tensor header appears to be well-formed."""
        if self.magic != FlatTensorHeader.EXPECTED_MAGIC:
            return False
        # A longer-than-expected header is tolerated so fields can be added
        # in future versions; shorter is malformed.
        return self.length >= FlatTensorHeader.EXPECTED_LENGTH

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the flat_tensor header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length, since there is no reason to create an
        invalid header.
        """
        fields = (
            # Extended header magic, so consumers can detect the header.
            self.EXPECTED_MAGIC,
            # uint32_t: size of this header, for forward compatibility.
            self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER),
            # uint64_t: offset to the start of the flatbuffer data, in bytes.
            self.flatbuffer_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: size of the serialized flatbuffer data, in bytes.
            self.flatbuffer_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: offset to the first segment, or zero if none.
            self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER),
            # uint64_t: size of all segment data, in bytes.
            self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER),
        )
        return b"".join(fields)
170+
171+
172+
class FlatTensorSerializer(DataSerializer):
    """A concrete implementation of the DataSerializer interface that
    serializes and deserializes data to/from the FlatTensor format.
    """

    def __init__(self, config: Optional[FlatTensorConfig] = None) -> None:
        """FlatTensorConfig holds information required for serialization,
        eg. alignment.
        """
        if config is None:
            self.config: FlatTensorConfig = FlatTensorConfig()
        else:
            self.config: FlatTensorConfig = config

    def serialize(
        self,
        data: DataPayload,
    ) -> Cord:
        """Serializes a list of tensor metadata and tensors into a blob.

        Resulting layout:
            [header (padded)] [flatbuffer (padded)] [segment data]

        Args:
            data: The tensor buffers and per-tensor metadata to serialize.
        Returns:
            A Cord containing the entire serialized .ptd payload.
        """
        flat_tensor_metadata: List[TensorMetadata] = []
        flat_tensor_data: Cord = Cord()

        # Map buffer index -> offset within flat_tensor_data, so buffers
        # shared by multiple tensors are stored only once.
        saved_offsets: Dict[int, int] = {}

        for fqn, tensor_entry in data.fqn_to_tensor.items():
            assert tensor_entry.layout is not None
            # Check index into the tensor buffers is valid.
            assert tensor_entry.buffer_index < len(
                data.buffers
            ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."

            # Check if the tensor has already been appended to the flat_tensor_data.
            offset = saved_offsets.get(tensor_entry.buffer_index, -1)
            if offset == -1:
                if len(flat_tensor_data) > 0:
                    # Add padding to round off the previous tensor offset.
                    pad_length = padding_required(
                        len(flat_tensor_data), self.config.tensor_alignment
                    )
                    flat_tensor_data.append(b"\x00" * pad_length)
                # Add to saved offsets.
                offset = len(flat_tensor_data)
                saved_offsets[tensor_entry.buffer_index] = offset
                # Append to flat_tensor_data at the offset.
                flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])

            flat_tensor_metadata.append(
                TensorMetadata(
                    fully_qualified_name=fqn,
                    scalar_type=tensor_entry.layout.scalar_type,
                    sizes=tensor_entry.layout.sizes,
                    dim_order=tensor_entry.layout.dim_order,
                    segment_index=0,
                    offset=offset,
                )
            )

        # Pad flat_tensor_data to segment alignment.
        segment_pad_length = padding_required(
            len(flat_tensor_data), self.config.segment_alignment
        )
        if segment_pad_length > 0:
            flat_tensor_data.append(b"\x00" * segment_pad_length)

        # Create FlatTensor, which describes the contents of the file and
        # points to all the data segments. It will be serialized to flatbuffer.
        flat_tensor = FlatTensor(
            version=0,
            tensor_alignment=self.config.tensor_alignment,
            tensors=flat_tensor_metadata,
            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
        )

        flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
        padded_flatbuffer_length: int = aligned_size(
            input_size=len(flatbuffer_payload),
            alignment=self.config.tensor_alignment,
        )

        padded_header_length: int = aligned_size(
            input_size=FlatTensorHeader.EXPECTED_LENGTH,
            alignment=self.config.tensor_alignment,
        )

        # The segment data must begin on a segment_alignment boundary after
        # the padded header and padded flatbuffer.
        segment_base_offset = aligned_size(
            padded_flatbuffer_length + padded_header_length,
            self.config.segment_alignment,
        )

        # Create FlatTensorHeader, which stores the offsets and sizes of the
        # FlatTensor flatbuffer and the segment data.
        header_data: bytes = FlatTensorHeader(
            flatbuffer_offset=padded_header_length,
            flatbuffer_size=len(flatbuffer_payload),
            segment_base_offset=segment_base_offset,
            segment_data_size=len(flat_tensor_data),
        ).to_bytes()

        # Pad header to tensor alignment.
        header_data = pad_to(header_data, padded_header_length)

        # Pad the flatbuffer so the segment data starts exactly at
        # segment_base_offset. Padding only to padded_flatbuffer_length is
        # wrong when segment_alignment does not divide
        # (padded_header_length + padded_flatbuffer_length): the header
        # would then claim a segment_base_offset beyond where the segment
        # data is actually written.
        flatbuffer_payload.append(
            b"\x00"
            * (segment_base_offset - padded_header_length - len(flatbuffer_payload))
        )

        # Place everything into one segment.
        payload = Cord()
        payload.append(header_data)
        payload.append(flatbuffer_payload)
        payload.append(flat_tensor_data)

        return payload

    def deserialize(self, blob: Cord) -> DataPayload:
        """
        Deserializes a flat_tensor blob into a list of tensor metadata and tensors.

        Raises:
            NotImplementedError: Always; deserialization is not implemented yet.
        """
        raise NotImplementedError("deserialize_data")

extension/flat_tensor/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
3+
oncall("executorch")
4+
5+
# Unit tests for the FlatTensor serializer.
python_unittest(
    name = "serialize",
    srcs = [
        "test_serialize.py",
    ],
    deps = [
        "//executorch/extension/flat_tensor/serialize:serialize",
        "//executorch/extension/flat_tensor/serialize:schema",
    ],
)

0 commit comments

Comments
 (0)