15
15
)
16
16
17
17
from executorch .exir .schema import ScalarType
18
+ from executorch .extension .flat_tensor .serialize .flat_tensor_schema import TensorMetadata
18
19
19
20
from executorch .extension .flat_tensor .serialize .serialize import (
21
+ _convert_to_flat_tensor ,
22
+ FlatTensorConfig ,
20
23
FlatTensorHeader ,
21
24
FlatTensorSerializer ,
22
25
)
23
26
24
27
# Test artifacts
25
- TEST_TENSOR_BUFFER = [b"tensor" ]
28
+ TEST_TENSOR_BUFFER = [b"\x11 " * 4 , b" \x22 " * 32 ]
26
29
TEST_TENSOR_MAP = {
27
30
"fqn1" : 0 ,
28
31
"fqn2" : 0 ,
32
+ "fqn3" : 1 ,
29
33
}
30
34
31
35
TEST_TENSOR_LAYOUT = {
39
43
sizes = [1 , 1 , 1 ],
40
44
dim_order = typing .cast (List [bytes ], [0 , 1 , 2 ]),
41
45
),
46
+ "fqn3" : TensorLayout (
47
+ scalar_type = ScalarType .INT ,
48
+ sizes = [2 , 2 , 2 ],
49
+ dim_order = typing .cast (List [bytes ], [0 , 1 ]),
50
+ ),
42
51
}
43
52
44
53
45
54
class TestSerialize (unittest .TestCase ):
55
+ def check_tensor_metadata (
56
+ self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata
57
+ ) -> None :
58
+ self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
59
+ self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
60
+ self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
61
+
46
62
def test_serialize (self ) -> None :
47
- serializer : DataSerializer = FlatTensorSerializer ()
63
+ config = FlatTensorConfig ()
64
+ serializer : DataSerializer = FlatTensorSerializer (config )
48
65
49
66
data = bytes (
50
67
serializer .serialize_tensors (
@@ -54,14 +71,71 @@ def test_serialize(self) -> None:
54
71
)
55
72
)
56
73
74
+ # Check header.
57
75
header = FlatTensorHeader .from_bytes (data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
58
76
self .assertTrue (header .is_valid ())
59
77
60
78
self .assertEqual (header .flatbuffer_offset , 48 )
61
- self .assertEqual (header .flatbuffer_size , 200 )
62
- self .assertEqual (header .segment_base_offset , 256 )
63
- self .assertEqual (header .data_size , 16 )
79
+ self .assertEqual (header .flatbuffer_size , 288 )
80
+ self .assertEqual (header .segment_base_offset , 336 )
81
+ self .assertEqual (header .data_size , 48 )
64
82
65
83
self .assertEqual (
66
84
data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
67
85
)
86
+
87
+ # Check flat tensor data.
88
+ flat_tensor_bytes = data [
89
+ header .flatbuffer_offset : header .flatbuffer_offset + header .flatbuffer_size
90
+ ]
91
+
92
+ flat_tensor = _convert_to_flat_tensor (flat_tensor_bytes )
93
+
94
+ self .assertEqual (flat_tensor .version , 0 )
95
+ self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
96
+
97
+ tensors = flat_tensor .tensors
98
+ self .assertEqual (len (tensors ), 3 )
99
+ self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
100
+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn1" ], tensors [0 ])
101
+ self .assertEqual (tensors [0 ].segment_index , 0 )
102
+ self .assertEqual (tensors [0 ].offset , 0 )
103
+
104
+ self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
105
+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn2" ], tensors [1 ])
106
+ self .assertEqual (tensors [1 ].segment_index , 0 )
107
+ self .assertEqual (tensors [1 ].offset , 0 )
108
+
109
+ self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
110
+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn3" ], tensors [2 ])
111
+ self .assertEqual (tensors [2 ].segment_index , 0 )
112
+ self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
113
+
114
+ segments = flat_tensor .segments
115
+ self .assertEqual (len (segments ), 1 )
116
+ self .assertEqual (segments [0 ].offset , 0 )
117
+ self .assertEqual (segments [0 ].size , config .tensor_alignment * 3 )
118
+
119
+ # Check segment data.
120
+ segment_data = data [
121
+ header .segment_base_offset : header .segment_base_offset + segments [0 ].size
122
+ ]
123
+
124
+ t0_start = 0
125
+ t0_len = len (TEST_TENSOR_BUFFER [0 ])
126
+ t0_end = config .tensor_alignment
127
+ self .assertEqual (
128
+ segment_data [t0_start : t0_start + t0_len ], TEST_TENSOR_BUFFER [0 ]
129
+ )
130
+ padding = b"\x00 " * (t0_end - t0_len )
131
+ self .assertEqual (segment_data [t0_start + t0_len : t0_end ], padding )
132
+
133
+ t1_start = config .tensor_alignment
134
+ t1_len = len (TEST_TENSOR_BUFFER [1 ])
135
+ t1_end = config .tensor_alignment * 3
136
+ self .assertEqual (
137
+ segment_data [t1_start : t1_start + t1_len ],
138
+ TEST_TENSOR_BUFFER [1 ],
139
+ )
140
+ padding = b"\x00 " * (t1_end - (t1_len + t1_start ))
141
+ self .assertEqual (segment_data [t1_start + t1_len : t1_end ], padding )
0 commit comments