Skip to content

Commit f2d1eba

Browse files
committed
[executorch][flat_tensor] Serialize flat tensor tests
Pull Request resolved: #7269 Introduce _convert_to_flat_tensor, which interprets a flat_tensor blob as a flat_tensor schema. Use this for more comprehensive testing for flat tensor serialization, and later for deserialization. ghstack-source-id: 260059177 @exported-using-ghexport Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/)
1 parent 0ac32ee commit f2d1eba

File tree

2 files changed

+113
-4
lines changed

2 files changed

+113
-4
lines changed

extension/flat_tensor/serialize/serialize.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import pkg_resources
88
from executorch.exir._serialize._cord import Cord
9-
from executorch.exir._serialize._dataclass import _DataclassEncoder
9+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1010

11-
from executorch.exir._serialize._flatbuffer import _flatc_compile
11+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
1212
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
1313

1414
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -49,6 +49,32 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
4949
return Cord(output_file.read())
5050

5151

52+
def _convert_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
53+
"""Converts a flatbuffer to a FlatTensor and returns the dataclass."""
54+
with tempfile.TemporaryDirectory() as d:
55+
schema_path = os.path.join(d, "flat_tensor.fbs")
56+
with open(schema_path, "wb") as schema_file:
57+
schema_file.write(
58+
pkg_resources.resource_string(__name__, "flat_tensor.fbs")
59+
)
60+
61+
scalar_type_path = os.path.join(d, "scalar_type.fbs")
62+
with open(scalar_type_path, "wb") as scalar_type_file:
63+
scalar_type_file.write(
64+
pkg_resources.resource_string(__name__, "scalar_type.fbs")
65+
)
66+
67+
bin_path = os.path.join(d, "flat_tensor.bin")
68+
with open(bin_path, "wb") as bin_file:
69+
bin_file.write(flatbuffer)
70+
71+
_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
72+
73+
json_path = os.path.join(d, "flat_tensor.json")
74+
with open(json_path, "rb") as output_file:
75+
return _json_to_dataclass(json.load(output_file), cls=FlatTensor)
76+
77+
5278
@dataclass
5379
class FlatTensorConfig:
5480
tensor_alignment: int = 16

extension/flat_tensor/test/test_serialize.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
from executorch.exir._serialize.padding import aligned_size
1717

1818
from executorch.exir.schema import ScalarType
19+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata
1920

2021
from executorch.extension.flat_tensor.serialize.serialize import (
22+
_convert_to_flat_tensor,
2123
FlatTensorConfig,
2224
FlatTensorHeader,
2325
FlatTensorSerializer,
2426
)
2527

2628
# Test artifacts.
27-
TEST_TENSOR_BUFFER = [b"tensor"]
29+
TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
2830
TEST_TENSOR_MAP = {
2931
"fqn1": TensorEntry(
3032
buffer_index=0,
@@ -42,6 +44,14 @@
4244
dim_order=[0, 1, 2],
4345
),
4446
),
47+
"fqn3": TensorEntry(
48+
buffer_index=1,
49+
layout=TensorLayout(
50+
scalar_type=ScalarType.INT,
51+
sizes=[2, 2, 2],
52+
dim_order=[0, 1],
53+
),
54+
),
4555
}
4656
TEST_DATA_PAYLOAD = DataPayload(
4757
buffers=TEST_TENSOR_BUFFER,
@@ -50,12 +60,21 @@
5060

5161

5262
class TestSerialize(unittest.TestCase):
63+
# TODO(T211851359): improve test coverage.
64+
def check_tensor_metadata(
65+
self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
66+
) -> None:
67+
self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type)
68+
self.assertEqual(tensor_layout.sizes, tensor_metadata.sizes)
69+
self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order)
70+
5371
def test_serialize(self) -> None:
5472
config = FlatTensorConfig()
5573
serializer: DataSerializer = FlatTensorSerializer(config)
5674

5775
data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
5876

77+
# Check header.
5978
header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
6079
self.assertTrue(header.is_valid())
6180

@@ -75,9 +94,73 @@ def test_serialize(self) -> None:
7594
self.assertTrue(header.segment_base_offset, expected_segment_base_offset)
7695

7796
# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
78-
self.assertEqual(header.segment_data_size, config.segment_alignment)
97+
expected_segment_data_size = aligned_size(
98+
sum(len(buffer) for buffer in TEST_TENSOR_BUFFER), config.segment_alignment
99+
)
100+
self.assertEqual(header.segment_data_size, expected_segment_data_size)
79101

80102
# Confirm the flatbuffer magic is present.
81103
self.assertEqual(
82104
data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
83105
)
106+
107+
# Check flat tensor data.
108+
flat_tensor_bytes = data[
109+
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
110+
]
111+
112+
flat_tensor = _convert_to_flat_tensor(flat_tensor_bytes)
113+
114+
self.assertEqual(flat_tensor.version, 0)
115+
self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)
116+
117+
tensors = flat_tensor.tensors
118+
self.assertEqual(len(tensors), 3)
119+
self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
120+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0])
121+
self.assertEqual(tensors[0].segment_index, 0)
122+
self.assertEqual(tensors[0].offset, 0)
123+
124+
self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
125+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1])
126+
self.assertEqual(tensors[1].segment_index, 0)
127+
self.assertEqual(tensors[1].offset, 0)
128+
129+
self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
130+
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2])
131+
self.assertEqual(tensors[2].segment_index, 0)
132+
self.assertEqual(tensors[2].offset, config.tensor_alignment)
133+
134+
segments = flat_tensor.segments
135+
self.assertEqual(len(segments), 1)
136+
self.assertEqual(segments[0].offset, 0)
137+
self.assertEqual(segments[0].size, config.tensor_alignment * 3)
138+
139+
# Check segment data.
140+
self.assertEqual(
141+
header.segment_base_offset + header.segment_data_size, len(data)
142+
)
143+
self.assertTrue(segments[0].size <= header.segment_data_size)
144+
145+
segment_data = data[
146+
header.segment_base_offset : header.segment_base_offset + segments[0].size
147+
]
148+
149+
t0_start = 0
150+
t0_len = len(TEST_TENSOR_BUFFER[0])
151+
t0_end = config.tensor_alignment
152+
self.assertEqual(
153+
segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
154+
)
155+
padding = b"\x00" * (t0_end - t0_len)
156+
self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)
157+
158+
t1_start = config.tensor_alignment
159+
t1_len = len(TEST_TENSOR_BUFFER[1])
160+
t1_end = config.tensor_alignment * 3
161+
self.assertEqual(
162+
segment_data[t1_start : t1_start + t1_len],
163+
TEST_TENSOR_BUFFER[1],
164+
)
165+
padding = b"\x00" * (t1_end - (t1_len + t1_start))
166+
self.assertEqual(segment_data[t1_start + t1_len : t1_end], padding)

0 commit comments

Comments
 (0)