16
16
from executorch .exir ._serialize .padding import aligned_size
17
17
18
18
from executorch .exir .schema import ScalarType
19
+ from executorch .extension .flat_tensor .serialize .flat_tensor_schema import TensorMetadata
19
20
20
21
from executorch .extension .flat_tensor .serialize .serialize import (
22
+ _deserialize_to_flat_tensor ,
21
23
FlatTensorConfig ,
22
24
FlatTensorHeader ,
23
25
FlatTensorSerializer ,
24
26
)
25
27
26
28
# Test artifacts.
27
- TEST_TENSOR_BUFFER = [b"tensor" ]
29
+ TEST_TENSOR_BUFFER = [b"\x11 " * 4 , b" \x22 " * 32 ]
28
30
TEST_TENSOR_MAP = {
29
31
"fqn1" : TensorEntry (
30
32
buffer_index = 0 ,
42
44
dim_order = [0 , 1 , 2 ],
43
45
),
44
46
),
47
+ "fqn3" : TensorEntry (
48
+ buffer_index = 1 ,
49
+ layout = TensorLayout (
50
+ scalar_type = ScalarType .INT ,
51
+ sizes = [2 , 2 , 2 ],
52
+ dim_order = [0 , 1 ],
53
+ ),
54
+ ),
45
55
}
46
56
TEST_DATA_PAYLOAD = DataPayload (
47
57
buffers = TEST_TENSOR_BUFFER ,
50
60
51
61
52
62
class TestSerialize (unittest .TestCase ):
63
+ # TODO(T211851359): improve test coverage.
64
+ def check_tensor_metadata (
65
+ self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata
66
+ ) -> None :
67
+ self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
68
+ self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
69
+ self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
70
+
53
71
def test_serialize (self ) -> None :
54
72
config = FlatTensorConfig ()
55
73
serializer : DataSerializer = FlatTensorSerializer (config )
56
74
57
- data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
75
+ serialized_data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
58
76
59
- header = FlatTensorHeader .from_bytes (data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
77
+ # Check header.
78
+ header = FlatTensorHeader .from_bytes (serialized_data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
60
79
self .assertTrue (header .is_valid ())
61
80
62
81
# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -75,9 +94,81 @@ def test_serialize(self) -> None:
75
94
self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
76
95
77
96
# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
78
- self .assertEqual (header .segment_data_size , config .segment_alignment )
97
+ expected_segment_data_size = aligned_size (
98
+ sum (len (buffer ) for buffer in TEST_TENSOR_BUFFER ), config .segment_alignment
99
+ )
100
+ self .assertEqual (header .segment_data_size , expected_segment_data_size )
79
101
80
102
# Confirm the flatbuffer magic is present.
81
103
self .assertEqual (
82
- data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
104
+ serialized_data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
83
105
)
106
+
107
+ # Check flat tensor data.
108
+ flat_tensor_bytes = serialized_data [
109
+ header .flatbuffer_offset : header .flatbuffer_offset + header .flatbuffer_size
110
+ ]
111
+
112
+ flat_tensor = _deserialize_to_flat_tensor (flat_tensor_bytes )
113
+
114
+ self .assertEqual (flat_tensor .version , 0 )
115
+ self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
116
+
117
+ tensors = flat_tensor .tensors
118
+ self .assertEqual (len (tensors ), 3 )
119
+ self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
120
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
121
+ self .assertEqual (tensors [0 ].segment_index , 0 )
122
+ self .assertEqual (tensors [0 ].offset , 0 )
123
+
124
+ self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
125
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
126
+ self .assertEqual (tensors [1 ].segment_index , 0 )
127
+ self .assertEqual (tensors [1 ].offset , 0 )
128
+
129
+ self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
130
+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
131
+ self .assertEqual (tensors [2 ].segment_index , 0 )
132
+ self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
133
+
134
+ segments = flat_tensor .segments
135
+ self .assertEqual (len (segments ), 1 )
136
+ self .assertEqual (segments [0 ].offset , 0 )
137
+ self .assertEqual (segments [0 ].size , config .tensor_alignment * 3 )
138
+
139
+ # Length of serialized_data matches segment_base_offset + segment_data_size.
140
+ self .assertEqual (
141
+ header .segment_base_offset + header .segment_data_size , len (serialized_data )
142
+ )
143
+ self .assertTrue (segments [0 ].size <= header .segment_data_size )
144
+
145
+ # Check the contents of the segment. Expecting two tensors from
146
+ # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
147
+ segment_data = serialized_data [
148
+ header .segment_base_offset : header .segment_base_offset + segments [0 ].size
149
+ ]
150
+
151
+ # Tensor: b"\x11" * 4
152
+ t0_start = 0
153
+ t0_len = len (TEST_TENSOR_BUFFER [0 ])
154
+ t0_end = t0_start + aligned_size (t0_len , config .tensor_alignment )
155
+ self .assertEqual (
156
+ segment_data [t0_start : t0_start + t0_len ], TEST_TENSOR_BUFFER [0 ]
157
+ )
158
+ padding = b"\x00 " * (t0_end - t0_len )
159
+ self .assertEqual (segment_data [t0_start + t0_len : t0_end ], padding )
160
+
161
+ # Tensor: b"\x22" * 32
162
+ t1_start = t0_end
163
+ t1_len = len (TEST_TENSOR_BUFFER [1 ])
164
+ t1_end = t1_start + aligned_size (t1_len , config .tensor_alignment )
165
+ self .assertEqual (
166
+ segment_data [t1_start : t1_start + t1_len ],
167
+ TEST_TENSOR_BUFFER [1 ],
168
+ )
169
+ padding = b"\x00 " * (t1_end - (t1_len + t1_start ))
170
+ self .assertEqual (segment_data [t1_start + t1_len : t1_start + t1_end ], padding )
171
+
172
+ # Check length of the segment is expected.
173
+ self .assertEqual (segments [0 ].size , aligned_size (t1_end , config .segment_alignment ))
174
+ self .assertEqual (segments [0 ].size , header .segment_data_size )
0 commit comments