@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
52
52
for (int i = 0 ; i < tensors->size (); i++) {
53
53
if (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
54
54
0 ) {
55
- // TODO(T214294528): Support multiple segments in FlatTensor.
56
- if (tensors->Get (i)->segment_index () != 0 ) {
57
- return Error::InvalidExternalData;
58
- }
59
- return tensors->Get (i);
55
+ const auto * metadata = tensors->Get (i);
56
+ ET_CHECK_OR_RETURN_ERROR (
57
+ metadata->segment_index () >= 0 && metadata->offset () >= 0 ,
58
+ InvalidExternalData,
59
+ " Invalid segment_index %d or offset %" PRIu64 " ; malformed PTD file." ,
60
+ metadata->segment_index (),
61
+ metadata->offset ());
62
+ return metadata;
60
63
}
61
64
}
62
65
return Error::NotFound;
@@ -75,6 +78,23 @@ Result<const TensorLayout> create_tensor_layout(
75
78
scalar_type);
76
79
}
77
80
81
+ Result<int > get_and_check_segment_offset (
82
+ const flatbuffers::Vector<
83
+ flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
84
+ const flat_tensor_flatbuffer::TensorMetadata* metadata) {
85
+ ET_CHECK_OR_RETURN_ERROR (
86
+ segments != nullptr ,
87
+ InvalidExternalData,
88
+ " No segments in external data flatbuffer." );
89
+
90
+ ET_CHECK_OR_RETURN_ERROR (
91
+ metadata->segment_index () < segments->size (),
92
+ InvalidExternalData,
93
+ " Invalid segment_index %d; malformed PTD file." ,
94
+ metadata->segment_index ());
95
+ return segments->Get (metadata->segment_index ())->offset ();
96
+ }
97
+
78
98
} // namespace
79
99
80
100
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata (
@@ -89,39 +109,73 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
89
109
90
110
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
91
111
const char * key) const {
92
- auto tensor_metadata = flat_tensor_->tensors ();
93
-
94
- Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
95
- get_flat_tensor_metadata (key, tensor_metadata);
96
- if (!metadata_res.ok ()) {
97
- return metadata_res.error ();
112
+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
113
+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
114
+ if (!metadata.ok ()) {
115
+ return metadata.error ();
98
116
}
99
- const auto metadata = metadata_res. get ();
100
- if (metadata-> segment_index () < 0 || metadata-> offset () < 0 ) {
101
- // Invalid segment_index/offset; malformed PTD file.
102
- return Error::InvalidExternalData ;
117
+ Result< const TensorLayout> tensor_layout =
118
+ create_tensor_layout ( metadata. get ());
119
+ if (!tensor_layout. ok ()) {
120
+ return tensor_layout. error () ;
103
121
}
104
-
105
- Result< const TensorLayout> tensor_layout_res = create_tensor_layout ( metadata);
106
- if (!tensor_layout_res .ok ()) {
107
- return tensor_layout_res .error ();
122
+ Result< int > segment_offset =
123
+ get_and_check_segment_offset (flat_tensor_-> segments (), metadata. get () );
124
+ if (!segment_offset .ok ()) {
125
+ return segment_offset .error ();
108
126
}
109
127
110
- // This FreeableBuffer doesn't own the underlying data, and will not free it,
111
- // which is why the free function is a nullptr.
112
- // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113
- // FreeableBuffer own it.
114
- return FreeableBuffer (
115
- static_cast <const uint8_t *>(data_ro_.data ()) + metadata->offset (),
116
- tensor_layout_res.get ().nbytes (),
117
- nullptr );
128
+ // Load constant data.
129
+ ET_CHECK_OR_RETURN_ERROR (
130
+ segment_offset.get () <
131
+ header_.segment_base_offset + header_.segment_data_size ,
132
+ InvalidExternalData,
133
+ " Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %" PRIu64
134
+ " ; malformed PTD file." ,
135
+ segment_offset.get (),
136
+ header_.segment_base_offset + header_.segment_data_size );
137
+ return loader_->load (
138
+ header_.segment_base_offset + segment_offset.get () +
139
+ metadata.get ()->offset (),
140
+ tensor_layout.get ().nbytes (),
141
+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118
142
}
119
143
120
144
ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121
145
ET_UNUSED const char * key,
122
146
ET_UNUSED void * buffer,
123
147
ET_UNUSED size_t size) const {
124
- return Error::NotImplemented;
148
+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
149
+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
150
+ if (!metadata.ok ()) {
151
+ return metadata.error ();
152
+ }
153
+ Result<const TensorLayout> tensor_layout =
154
+ create_tensor_layout (metadata.get ());
155
+ if (!tensor_layout.ok ()) {
156
+ return tensor_layout.error ();
157
+ }
158
+ ET_CHECK_OR_RETURN_ERROR (
159
+ size < tensor_layout.get ().nbytes (),
160
+ InvalidArgument,
161
+ " Buffer size %zu is smaller than tensor size %zu" ,
162
+ size,
163
+ tensor_layout.get ().nbytes ());
164
+
165
+ Result<int > segment_offset =
166
+ get_and_check_segment_offset (flat_tensor_->segments (), metadata.get ());
167
+ if (!segment_offset.ok ()) {
168
+ return segment_offset.error ();
169
+ }
170
+ // Load mutable data.
171
+ DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
172
+ DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
173
+ return loader_->load_into (
174
+ header_.segment_base_offset + segment_offset.get () +
175
+ metadata.get ()->offset (),
176
+ tensor_layout.get ().nbytes (),
177
+ info,
178
+ buffer);
125
179
}
126
180
127
181
ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
@@ -138,45 +192,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138
192
139
193
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
140
194
DataLoader* loader) {
141
- // Load data map.
142
- size_t flatbuffer_offset = 0 ;
143
- size_t flatbuffer_size = 0 ;
144
- size_t segment_base_offset = 0 ;
145
- size_t segment_data_size = 0 ;
146
- {
147
- // Check header.
148
- Result<FreeableBuffer> header = loader->load (
149
- /* offset=*/ 0 ,
150
- FlatTensorHeader::kNumHeadBytes ,
151
- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152
- if (!header.ok ()) {
153
- return header.error ();
154
- }
155
- Result<FlatTensorHeader> fh =
156
- FlatTensorHeader::Parse (header->data (), header->size ());
157
- if (fh.ok ()) {
158
- // The header has the data map size.
159
- flatbuffer_offset = fh->flatbuffer_offset ;
160
- flatbuffer_size = fh->flatbuffer_size ;
161
- segment_base_offset = fh->segment_base_offset ;
162
- segment_data_size = fh->segment_data_size ;
163
- } else if (fh.error () == Error::NotFound) {
164
- // No header, throw error.
165
- ET_LOG (Error, " No FlatTensorHeader found." );
166
- return fh.error ();
167
- } else {
168
- // corruption, throw error.
169
- ET_LOG (Error, " Flat tensor header may be corrupt." );
170
- return fh.error ();
171
- }
195
+ // Check header.
196
+ Result<FreeableBuffer> header = loader->load (
197
+ /* offset=*/ 0 ,
198
+ FlatTensorHeader::kNumHeadBytes ,
199
+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
200
+ if (!header.ok ()) {
201
+ ET_LOG (Error, " Failed to load header." );
202
+ return header.error ();
203
+ }
204
+ Result<FlatTensorHeader> fh =
205
+ FlatTensorHeader::Parse (header->data (), header->size ());
206
+ if (fh.error () == Error::NotFound) {
207
+ // No header, throw error.
208
+ ET_LOG (Error, " No FlatTensorHeader found." );
209
+ return fh.error ();
210
+ } else if (fh.error () != Error::Ok) {
211
+ // corruption, throw error.
212
+ ET_LOG (Error, " Flat tensor header may be corrupt." );
213
+ return fh.error ();
172
214
}
173
215
174
216
// Load flatbuffer data as a segment.
175
217
Result<FreeableBuffer> flat_tensor_data = loader->load (
176
218
/* offset=*/ 0 ,
177
- flatbuffer_offset + flatbuffer_size,
219
+ fh-> flatbuffer_offset + fh-> flatbuffer_size ,
178
220
DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179
221
if (!flat_tensor_data.ok ()) {
222
+ ET_LOG (Error, " Failed to load flat_tensor data." );
180
223
return flat_tensor_data.error ();
181
224
}
182
225
@@ -204,54 +247,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204
247
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205
248
flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206
249
207
- // Validate flatbuffer data.
208
- flatbuffers::Verifier verifier (
209
- reinterpret_cast <const uint8_t *>(flat_tensor_data->data ()),
210
- flat_tensor_data->size ());
211
- bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212
- ET_CHECK_OR_RETURN_ERROR (
213
- ok,
214
- InvalidExternalData,
215
- " Verification failed; data may be truncated or corrupt" );
216
-
217
- // Get pointer to tensor metadata.
218
- const auto * s_tensor_metadata = flat_tensor->tensors ();
219
- if (s_tensor_metadata == nullptr ) {
220
- ET_LOG (Error, " FlatTensor has no tensor metadata." );
221
- return Error::InvalidExternalData;
222
- }
223
-
224
- // Load constant data.
225
- const auto * s_data_segment = flat_tensor->segments ();
226
-
227
- // TODO(T214294528): Support multiple segments in FlatTensor.
228
- if (s_data_segment->size () != 1 ) {
229
- ET_LOG (
230
- Error,
231
- " FlatTensor has %u segments, only 1 supported." ,
232
- s_data_segment->size ());
233
- }
234
- // First segment size should be <= the total segment data size.
235
- int segment_size = s_data_segment->Get (0 )->size ();
236
- int segment_offset = s_data_segment->Get (0 )->offset ();
237
- if (segment_size > segment_data_size) {
238
- ET_LOG (
239
- Error,
240
- " FlatTensor segment size %d > segment data size %zu" ,
241
- segment_size,
242
- segment_data_size);
243
- }
244
-
245
- Result<FreeableBuffer> data_ro = loader->load (
246
- /* offset=*/ segment_base_offset + segment_offset,
247
- segment_size,
248
- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249
- if (!data_ro.ok ()) {
250
- return data_ro.error ();
251
- }
252
-
253
250
return FlatTensorDataMap (
254
- std::move (flat_tensor_data .get ()), flat_tensor, std::move (data_ro .get ()));
251
+ fh .get (), std::move (flat_tensor_data .get ()), flat_tensor, loader );
255
252
}
256
253
257
254
} // namespace extension
0 commit comments