Skip to content

Commit 41448d2

Browse files
committed
[executorch][flat_tensor] Serialize flat tensor tests
More comprehensive testing for flat tensor serialization. Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/) ghstack-source-id: 257465213 Pull Request resolved: #7269
1 parent 3d41da7 commit 41448d2

File tree

2 files changed

+106
-7
lines changed

2 files changed

+106
-7
lines changed

extension/flat_tensor/serialize/serialize.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import pkg_resources
88
from executorch.exir._serialize._cord import Cord
9-
from executorch.exir._serialize._dataclass import _DataclassEncoder
9+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1010

11-
from executorch.exir._serialize._flatbuffer import _flatc_compile
11+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
1212
from executorch.exir._serialize.data_serializer import DataSerializer, SerializationInfo
1313

1414
from executorch.exir._serialize.utils import (
@@ -48,6 +48,31 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
4848
return Cord(output_file.read())
4949

5050

51+
def _convert_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
52+
with tempfile.TemporaryDirectory() as d:
53+
schema_path = os.path.join(d, "flat_tensor.fbs")
54+
with open(schema_path, "wb") as schema_file:
55+
schema_file.write(
56+
pkg_resources.resource_string(__name__, "flat_tensor.fbs")
57+
)
58+
59+
scalar_type_path = os.path.join(d, "scalar_type.fbs")
60+
with open(scalar_type_path, "wb") as scalar_type_file:
61+
scalar_type_file.write(
62+
pkg_resources.resource_string(__name__, "scalar_type.fbs")
63+
)
64+
65+
bin_path = os.path.join(d, "flat_tensor.bin")
66+
with open(bin_path, "wb") as bin_file:
67+
bin_file.write(flatbuffer)
68+
69+
_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
70+
71+
json_path = os.path.join(d, "flat_tensor.json")
72+
with open(json_path, "rb") as output_file:
73+
return _json_to_dataclass(json.load(output_file), cls=FlatTensor)
74+
75+
5176
@dataclass
5277
class FlatTensorConfig:
5378
tensor_alignment: int = 16

extension/flat_tensor/test/test_serialize.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,21 @@
1515
)
1616

1717
from executorch.exir.schema import ScalarType
18+
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata
1819

1920
from executorch.extension.flat_tensor.serialize.serialize import (
21+
_convert_to_flat_tensor,
22+
FlatTensorConfig,
2023
FlatTensorHeader,
2124
FlatTensorSerializer,
2225
)
2326

2427
# Test artifacts
25-
TEST_TENSOR_BUFFER = [b"tensor"]
28+
TEST_TENSOR_BUFFER = [b"\x11"*4, b"\x22"*32]
2629
TEST_TENSOR_MAP = {
2730
"fqn1": 0,
2831
"fqn2": 0,
32+
"fqn3": 1,
2933
}
3034

3135
TEST_TENSOR_LAYOUT = {
@@ -39,12 +43,25 @@
3943
dim_sizes=[1, 1, 1],
4044
dim_order=typing.cast(List[bytes], [0, 1, 2]),
4145
),
46+
"fqn3": TensorLayout(
47+
scalar_type=ScalarType.INT,
48+
dim_sizes=[2, 2, 2],
49+
dim_order=typing.cast(List[bytes], [0, 1]),
50+
),
4251
}
4352

4453

4554
class TestSerialize(unittest.TestCase):
55+
def check_tensor_metadata(
56+
self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
57+
) -> None:
58+
self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type)
59+
self.assertEqual(tensor_layout.dim_sizes, tensor_metadata.dim_sizes)
60+
self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order)
61+
4662
def test_serialize(self) -> None:
47-
serializer: DataSerializer = FlatTensorSerializer()
63+
config = FlatTensorConfig()
64+
serializer: DataSerializer = FlatTensorSerializer(config)
4865

4966
data = bytes(
5067
serializer.serialize_tensors(
@@ -54,14 +71,71 @@ def test_serialize(self) -> None:
5471
)
5572
)
5673

74+
# Check header.
5775
header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
5876
self.assertTrue(header.is_valid())
5977

6078
self.assertEqual(header.flatbuffer_offset, 48)
61-
self.assertEqual(header.flatbuffer_size, 200)
62-
self.assertEqual(header.segment_base_offset, 256)
63-
self.assertEqual(header.data_size, 16)
79+
self.assertEqual(header.flatbuffer_size, 288)
80+
self.assertEqual(header.segment_base_offset, 336)
81+
self.assertEqual(header.data_size, 48)
6482

6583
self.assertEqual(
6684
data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
6785
)
86+
87+
# Check flat tensor data.
88+
flat_tensor_bytes = data[
89+
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
90+
]
91+
92+
flat_tensor = _convert_to_flat_tensor(flat_tensor_bytes)
93+
94+
self.assertEqual(flat_tensor.version, 0)
95+
self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)
96+
97+
tensors = flat_tensor.tensors
98+
self.assertEqual(len(tensors), 3)
99+
self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
100+
self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn1"], tensors[0])
101+
self.assertEqual(tensors[0].segment_index, 0)
102+
self.assertEqual(tensors[0].offset, 0)
103+
104+
self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
105+
self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn2"], tensors[1])
106+
self.assertEqual(tensors[1].segment_index, 0)
107+
self.assertEqual(tensors[1].offset, 0)
108+
109+
self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
110+
self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn3"], tensors[2])
111+
self.assertEqual(tensors[2].segment_index, 0)
112+
self.assertEqual(tensors[2].offset, config.tensor_alignment)
113+
114+
segments = flat_tensor.segments
115+
self.assertEqual(len(segments), 1)
116+
self.assertEqual(segments[0].offset, 0)
117+
self.assertEqual(segments[0].size, config.tensor_alignment * 3)
118+
119+
# Check segment data.
120+
segment_data = data[
121+
header.segment_base_offset : header.segment_base_offset + segments[0].size
122+
]
123+
124+
t0_start = 0
125+
t0_len = len(TEST_TENSOR_BUFFER[0])
126+
t0_end = config.tensor_alignment
127+
self.assertEqual(
128+
segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
129+
)
130+
padding = b"\x00" * (t0_end - t0_len)
131+
self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)
132+
133+
t1_start = config.tensor_alignment
134+
t1_len = len(TEST_TENSOR_BUFFER[1])
135+
t1_end = config.tensor_alignment * 3
136+
self.assertEqual(
137+
segment_data[t1_start : t1_start + t1_len],
138+
TEST_TENSOR_BUFFER[1],
139+
)
140+
padding = b"\x00" * (t1_end - (t1_len + t1_start))
141+
self.assertEqual(segment_data[t1_start + t1_len : t1_end], padding)

0 commit comments

Comments
 (0)