Skip to content

Commit 3bad16c

Browse files
committed
[executorch][flat_tensor] Serialize flat tensor tests
Pull Request resolved: #7269. Introduce _deserialize_to_flat_tensor, which decodes a serialized flat_tensor flatbuffer blob back into the FlatTensor schema dataclass, and rename _convert_to_flatbuffer to _serialize_to_flatbuffer to match. Use the new helper for more comprehensive testing of flat tensor serialization, and later for deserialization. ghstack-source-id: 260969721 @exported-using-ghexport Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/)
1 parent 86f5e73 commit 3bad16c

File tree

2 files changed

+127
-10
lines changed

2 files changed

+127
-10
lines changed

extension/flat_tensor/serialize/serialize.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import pkg_resources
88
from executorch.exir._serialize._cord import Cord
9-
from executorch.exir._serialize._dataclass import _DataclassEncoder
9+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1010

11-
from executorch.exir._serialize._flatbuffer import _flatc_compile
11+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
1212
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
1313

1414
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -25,8 +25,8 @@
2525
)
2626

2727

28-
def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
29-
"""Converts a FlatTensor to a flatbuffer and returns the serialized data."""
28+
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
29+
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
3030
flat_tensor_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
3131
with tempfile.TemporaryDirectory() as d:
3232
schema_path = os.path.join(d, "flat_tensor.fbs")
@@ -49,6 +49,32 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
4949
return Cord(output_file.read())
5050

5151

52+
def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
    """Decode serialized flat_tensor flatbuffer bytes into a FlatTensor.

    The packaged ``flat_tensor.fbs`` and ``scalar_type.fbs`` schemas and the
    given bytes are written into a temporary directory, ``_flatc_decompile``
    is run on them, and the ``flat_tensor.json`` file found in that directory
    afterwards is loaded and converted with ``_json_to_dataclass``.

    Args:
        flatbuffer: The raw flatbuffer bytes to decode.

    Returns:
        The FlatTensor dataclass built from the decompiled JSON.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        # Stage both schema files side by side; only flat_tensor.fbs is passed
        # to the decompiler (presumably it includes scalar_type.fbs -- confirm
        # against the schema).
        for schema_name in ("flat_tensor.fbs", "scalar_type.fbs"):
            with open(os.path.join(work_dir, schema_name), "wb") as fbs_file:
                fbs_file.write(pkg_resources.resource_string(__name__, schema_name))
        schema_path = os.path.join(work_dir, "flat_tensor.fbs")

        # Write the binary payload where the decompiler can read it.
        bin_path = os.path.join(work_dir, "flat_tensor.bin")
        with open(bin_path, "wb") as bin_file:
            bin_file.write(flatbuffer)

        _flatc_decompile(work_dir, schema_path, bin_path, ["--raw-binary"])

        # The JSON must be read before the temporary directory is removed.
        with open(os.path.join(work_dir, "flat_tensor.json"), "rb") as json_file:
            decoded = json.load(json_file)
        return _json_to_dataclass(decoded, cls=FlatTensor)
76+
77+
5278
@dataclass
5379
class FlatTensorConfig:
5480
tensor_alignment: int = 16
@@ -236,7 +262,7 @@ def serialize(
236262
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
237263
)
238264

239-
flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
265+
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
240266
padded_flatbuffer_length: int = aligned_size(
241267
input_size=len(flatbuffer_payload),
242268
alignment=self.config.tensor_alignment,

extension/flat_tensor/test/test_serialize.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
from executorch.exir._serialize.padding import aligned_size
1717

1818
from executorch.exir.schema import ScalarType
19+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata
1920

2021
from executorch.extension.flat_tensor.serialize.serialize import (
22+
_deserialize_to_flat_tensor,
2123
FlatTensorConfig,
2224
FlatTensorHeader,
2325
FlatTensorSerializer,
2426
)
2527

2628
# Test artifacts.
27-
TEST_TENSOR_BUFFER = [b"tensor"]
29+
TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
2830
TEST_TENSOR_MAP = {
2931
"fqn1": TensorEntry(
3032
buffer_index=0,
@@ -42,6 +44,14 @@
4244
dim_order=[0, 1, 2],
4345
),
4446
),
47+
"fqn3": TensorEntry(
48+
buffer_index=1,
49+
layout=TensorLayout(
50+
scalar_type=ScalarType.INT,
51+
sizes=[2, 2, 2],
52+
dim_order=[0, 1],
53+
),
54+
),
4555
}
4656
TEST_DATA_PAYLOAD = DataPayload(
4757
buffers=TEST_TENSOR_BUFFER,
@@ -50,13 +60,22 @@
5060

5161

5262
class TestSerialize(unittest.TestCase):
63+
# TODO(T211851359): improve test coverage.
64+
def check_tensor_metadata(
    self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
) -> None:
    """Assert that a deserialized tensor's metadata matches its source layout.

    Compares ``scalar_type``, ``sizes`` and ``dim_order`` in that order and
    fails on the first field that differs.

    Args:
        tensor_layout: The layout supplied to the serializer.
        tensor_metadata: The metadata entry read back from the flatbuffer.
    """
    for field_name in ("scalar_type", "sizes", "dim_order"):
        self.assertEqual(
            getattr(tensor_layout, field_name),
            getattr(tensor_metadata, field_name),
        )
70+
5371
def test_serialize(self) -> None:
5472
config = FlatTensorConfig()
5573
serializer: DataSerializer = FlatTensorSerializer(config)
5674

57-
data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
75+
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
5876

59-
header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
77+
# Check header.
78+
header = FlatTensorHeader.from_bytes(serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH])
6079
self.assertTrue(header.is_valid())
6180

6281
# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -75,9 +94,81 @@ def test_serialize(self) -> None:
7594
self.assertTrue(header.segment_base_offset, expected_segment_base_offset)
7695

7796
# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
78-
self.assertEqual(header.segment_data_size, config.segment_alignment)
97+
expected_segment_data_size = aligned_size(
98+
sum(len(buffer) for buffer in TEST_TENSOR_BUFFER), config.segment_alignment
99+
)
100+
self.assertEqual(header.segment_data_size, expected_segment_data_size)
79101

80102
# Confirm the flatbuffer magic is present.
81103
self.assertEqual(
82-
data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
104+
serialized_data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
83105
)
106+
107+
# Check flat tensor data.
108+
flat_tensor_bytes = serialized_data[
109+
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
110+
]
111+
112+
flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
113+
114+
self.assertEqual(flat_tensor.version, 0)
115+
self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)
116+
117+
tensors = flat_tensor.tensors
118+
self.assertEqual(len(tensors), 3)
119+
self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
120+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0])
121+
self.assertEqual(tensors[0].segment_index, 0)
122+
self.assertEqual(tensors[0].offset, 0)
123+
124+
self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
125+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1])
126+
self.assertEqual(tensors[1].segment_index, 0)
127+
self.assertEqual(tensors[1].offset, 0)
128+
129+
self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
130+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2])
131+
self.assertEqual(tensors[2].segment_index, 0)
132+
self.assertEqual(tensors[2].offset, config.tensor_alignment)
133+
134+
segments = flat_tensor.segments
135+
self.assertEqual(len(segments), 1)
136+
self.assertEqual(segments[0].offset, 0)
137+
self.assertEqual(segments[0].size, config.tensor_alignment * 3)
138+
139+
# Length of serialized_data matches segment_base_offset + segment_data_size.
140+
self.assertEqual(
141+
header.segment_base_offset + header.segment_data_size, len(serialized_data)
142+
)
143+
self.assertTrue(segments[0].size <= header.segment_data_size)
144+
145+
# Check the contents of the segment. Expecting two tensors from
146+
# TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
147+
segment_data = serialized_data[
148+
header.segment_base_offset : header.segment_base_offset + segments[0].size
149+
]
150+
151+
# Tensor: b"\x11" * 4
152+
t0_start = 0
153+
t0_len = len(TEST_TENSOR_BUFFER[0])
154+
t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment)
155+
self.assertEqual(
156+
segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
157+
)
158+
padding = b"\x00" * (t0_end - t0_len)
159+
self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)
160+
161+
# Tensor: b"\x22" * 32
162+
t1_start = t0_end
163+
t1_len = len(TEST_TENSOR_BUFFER[1])
164+
t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment)
165+
self.assertEqual(
166+
segment_data[t1_start : t1_start + t1_len],
167+
TEST_TENSOR_BUFFER[1],
168+
)
169+
padding = b"\x00" * (t1_end - (t1_len + t1_start))
170+
self.assertEqual(segment_data[t1_start + t1_len : t1_start + t1_end], padding)
171+
172+
# Check length of the segment is expected.
173+
self.assertEqual(segments[0].size, aligned_size(t1_end, config.segment_alignment))
174+
self.assertEqual(segments[0].size, header.segment_data_size)

0 commit comments

Comments
 (0)