@@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
168
168
public:
169
169
// / A record of an operation performed on this DataLoader.
170
170
struct Operation {
171
- enum { Load, Free } op;
172
- size_t offset; // Set for Load; zero for Free.
173
- void * data; // Set for Free; nullptr for Load.
174
- size_t size; // Set for Load and Free.
171
+ enum { Load, Free, DeprecatedLoad } op;
172
+ size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
173
+ void * data; // Set for Free; nullptr for Load/DeprecatedLoad.
174
+ size_t size; // Set for Load/DeprecatedLoad and Free.
175
+ std::unique_ptr<const DataLoader::SegmentInfo>
176
+ segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
175
177
};
176
178
177
179
explicit DataLoaderSpy (DataLoader* delegate) : delegate_(delegate) {}
178
180
181
+ /* *
182
+ * Override the deprecated "Load" method. We will be looking to test that
183
+ * this function is not called if the new "load" method is called.
184
+ */
179
185
Result<FreeableBuffer> Load (size_t offset, size_t size) override {
180
186
Result<FreeableBuffer> buf = delegate_->Load (offset, size);
181
187
if (!buf.ok ()) {
182
188
return buf.error ();
183
189
}
184
- operations_.push_back ({Operation::Load, offset, /* data=*/ nullptr , size});
190
+ operations_.push_back (
191
+ {Operation::DeprecatedLoad,
192
+ offset,
193
+ /* data=*/ nullptr ,
194
+ size,
195
+ /* segment_info=*/ nullptr });
196
+ auto * context = new SpyContext (&operations_, std::move (buf.get ()));
197
+ // Use context->buffer since buf has been moved.
198
+ return FreeableBuffer (
199
+ context->buffer .data (), context->buffer .size (), FreeBuffer, context);
200
+ }
201
+
202
+ Result<FreeableBuffer>
203
+ load (size_t offset, size_t size, const SegmentInfo& segment_info) override {
204
+ Result<FreeableBuffer> buf = delegate_->load (offset, size, segment_info);
205
+ if (!buf.ok ()) {
206
+ return buf.error ();
207
+ }
208
+
209
+ auto segment_info_cpy =
210
+ std::make_unique<const DataLoader::SegmentInfo>(segment_info);
211
+ operations_.push_back (
212
+ {Operation::Load,
213
+ offset,
214
+ /* data=*/ nullptr ,
215
+ size,
216
+ /* segment_info=*/ std::move (segment_info_cpy)});
185
217
auto * context = new SpyContext (&operations_, std::move (buf.get ()));
186
218
// Use context->buffer since buf has been moved.
187
219
return FreeableBuffer (
@@ -200,6 +232,36 @@ class DataLoaderSpy : public DataLoader {
200
232
return operations_;
201
233
}
202
234
235
+ /* *
236
+ * Returns true if the DataLoader::load() method was called with the correct
237
+ * segment info.
238
+ */
239
+ bool UsedLoad (
240
+ DataLoader::SegmentInfo::Type segment_type,
241
+ const char * descriptor = nullptr ) const {
242
+ for (const auto & op : operations_) {
243
+ // We should not be using the deprecated DataLoader::Load() function.
244
+ if (op.op == Operation::DeprecatedLoad) {
245
+ return false ;
246
+ }
247
+ if (op.op != Operation::Load) {
248
+ continue ;
249
+ }
250
+ // We have a load op.
251
+ if (op.segment_info ->segment_type == segment_type) {
252
+ if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
253
+ // For non-backend segments, the descriptor is irrelevant / a nullptr.
254
+ return true ;
255
+ } else {
256
+ if (strcmp (op.segment_info ->descriptor , descriptor) == 0 ) {
257
+ return true ;
258
+ }
259
+ }
260
+ }
261
+ }
262
+ return false ;
263
+ }
264
+
203
265
/* *
204
266
* Returns true if the operations list shows that the provided data pointer
205
267
* was freed.
@@ -223,7 +285,8 @@ class DataLoaderSpy : public DataLoader {
223
285
224
286
static void FreeBuffer (void * context, void * data, size_t size) {
225
287
auto * sc = reinterpret_cast <SpyContext*>(context);
226
- sc->operations ->push_back ({Operation::Free, /* offset=*/ 0 , data, size});
288
+ sc->operations ->push_back (
289
+ {Operation::Free, /* offset=*/ 0 , data, size, /* segment_info=*/ nullptr });
227
290
delete sc;
228
291
}
229
292
@@ -333,7 +396,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
333
396
EXPECT_EQ (method_res.error (), Error::Ok);
334
397
335
398
// Demonstrate that our installed init was called.
336
- EXPECT_EQ (init_called, true );
399
+ EXPECT_TRUE (init_called);
337
400
338
401
// See if the processed data was freed.
339
402
bool processed_was_freed = spy_loader.WasFreed (processed_data);
@@ -444,6 +507,51 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
444
507
EXPECT_EQ (execute_handle, destroy_handle);
445
508
}
446
509
510
+ /* *
511
+ * Tests that the DataLoader's load is receiving the correct segment info for
512
+ * different types of segments.
513
+ */
514
+ TEST_P (BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
515
+ const void * processed_data = nullptr ;
516
+ StubBackend::singleton ().install_init (
517
+ [&](FreeableBuffer* processed,
518
+ __ET_UNUSED ArrayRef<CompileSpec> compile_specs,
519
+ __ET_UNUSED MemoryAllocator* runtime_allocator)
520
+ -> Result<DelegateHandle*> {
521
+ processed_data = processed->data ();
522
+ processed->Free ();
523
+ return nullptr ;
524
+ });
525
+
526
+ // Wrap the real loader in a spy so we can see which operations were
527
+ // performed.
528
+ Result<FileDataLoader> loader = FileDataLoader::from (program_path ());
529
+ ASSERT_EQ (loader.error (), Error::Ok);
530
+ DataLoaderSpy spy_loader (&loader.get ());
531
+
532
+ // Load the program.
533
+ Result<Program> program = Program::load (&spy_loader);
534
+ ASSERT_EQ (program.error (), Error::Ok);
535
+ ManagedMemoryManager mmm (kDefaultNonConstMemBytes , kDefaultRuntimeMemBytes );
536
+
537
+ // Expect that load was called correctly on program segments.
538
+ bool program_load_was_called =
539
+ spy_loader.UsedLoad (DataLoader::SegmentInfo::Type::Program, nullptr );
540
+
541
+ // Load a method.
542
+ Result<Method> method_res = program->load_method (" forward" , &mmm.get ());
543
+ EXPECT_EQ (method_res.error (), Error::Ok);
544
+
545
+ // Expect that load was called correctly on a backend segment.
546
+ bool backend_load_was_called = spy_loader.UsedLoad (
547
+ DataLoader::SegmentInfo::Type::Backend,
548
+ " backend_segment" ); // TODO(jackzhxng): replace with actual mock PTE
549
+ // file's backend_id in next chained PR.
550
+
551
+ EXPECT_TRUE (program_load_was_called);
552
+ EXPECT_EQ (backend_load_was_called, using_segments ());
553
+ }
554
+
447
555
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
448
556
// - Errors during init() or execute() result in runtime init/execution failures
449
557
// - Correct values are passed to init()/execute()
0 commit comments