@@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
168
168
public:
169
169
// / A record of an operation performed on this DataLoader.
170
170
struct Operation {
171
- enum { Load, Free } op;
172
- size_t offset; // Set for Load; zero for Free.
173
- void * data; // Set for Free; nullptr for Load.
174
- size_t size; // Set for Load and Free.
171
+ enum { Load, Free, DeprecatedLoad } op;
172
+ size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
173
+ void * data; // Set for Free; nullptr for Load/DeprecatedLoad.
174
+ size_t size; // Set for Load/DeprecatedLoad and Free.
175
+ std::unique_ptr<const DataLoader::SegmentInfo>
176
+ segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
175
177
};
176
178
177
179
explicit DataLoaderSpy (DataLoader* delegate) : delegate_(delegate) {}
178
180
181
+ /* *
182
+ * Override the deprecated "Load" method. We will be looking to test that
183
+ * this function is not called if the new "load" method is called.
184
+ */
179
185
Result<FreeableBuffer> Load (size_t offset, size_t size) override {
180
186
Result<FreeableBuffer> buf = delegate_->Load (offset, size);
181
187
if (!buf.ok ()) {
182
188
return buf.error ();
183
189
}
184
- operations_.push_back ({Operation::Load, offset, /* data=*/ nullptr , size});
190
+ operations_.push_back (
191
+ {Operation::DeprecatedLoad,
192
+ offset,
193
+ /* data=*/ nullptr ,
194
+ size,
195
+ /* segment_info=*/ nullptr });
196
+ auto * context = new SpyContext (&operations_, std::move (buf.get ()));
197
+ // Use context->buffer since buf has been moved.
198
+ return FreeableBuffer (
199
+ context->buffer .data (), context->buffer .size (), FreeBuffer, context);
200
+ }
201
+
202
+ Result<FreeableBuffer>
203
+ load (size_t offset, size_t size, const SegmentInfo& segment_info) override {
204
+ Result<FreeableBuffer> buf = delegate_->load (offset, size, segment_info);
205
+ if (!buf.ok ()) {
206
+ return buf.error ();
207
+ }
208
+
209
+ auto segment_info_cpy =
210
+ std::make_unique<const DataLoader::SegmentInfo>(segment_info);
211
+ operations_.push_back (
212
+ {Operation::Load,
213
+ offset,
214
+ /* data=*/ nullptr ,
215
+ size,
216
+ /* segment_info=*/ std::move (segment_info_cpy)});
185
217
auto * context = new SpyContext (&operations_, std::move (buf.get ()));
186
218
// Use context->buffer since buf has been moved.
187
219
return FreeableBuffer (
@@ -200,6 +232,33 @@ class DataLoaderSpy : public DataLoader {
200
232
return operations_;
201
233
}
202
234
235
+ /* *
236
+ * Returns true if the DataLoader::load() method was called with the correct
237
+ * segment info.
238
+ */
239
+ bool UsedLoad (
240
+ DataLoader::SegmentInfo::Type segment_type,
241
+ const char * descriptor) const {
242
+ for (const auto & op : operations_) {
243
+ // We should not be using the deprecated DataLoader::Load() function.
244
+ if (op.op == Operation::DeprecatedLoad)
245
+ return false ;
246
+ if (op.op != Operation::Load)
247
+ continue ;
248
+ // We have a load op.
249
+ if (op.segment_info ->segment_type == segment_type) {
250
+ if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
251
+ // For non-backend segments, the descriptor is irrelevant / a nullptr.
252
+ return true ;
253
+ } else {
254
+ if (strcmp (op.segment_info ->descriptor , descriptor) == 0 )
255
+ return true ;
256
+ }
257
+ }
258
+ }
259
+ return false ;
260
+ }
261
+
203
262
/* *
204
263
* Returns true if the operations list shows that the provided data pointer
205
264
* was freed.
@@ -223,7 +282,8 @@ class DataLoaderSpy : public DataLoader {
223
282
224
283
static void FreeBuffer (void * context, void * data, size_t size) {
225
284
auto * sc = reinterpret_cast <SpyContext*>(context);
226
- sc->operations ->push_back ({Operation::Free, /* offset=*/ 0 , data, size});
285
+ sc->operations ->push_back (
286
+ {Operation::Free, /* offset=*/ 0 , data, size, /* segment_info=*/ nullptr });
227
287
delete sc;
228
288
}
229
289
@@ -333,7 +393,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
333
393
EXPECT_EQ (method_res.error (), Error::Ok);
334
394
335
395
// Demonstrate that our installed init was called.
336
- EXPECT_EQ (init_called, true );
396
+ EXPECT_TRUE (init_called);
337
397
338
398
// See if the processed data was freed.
339
399
bool processed_was_freed = spy_loader.WasFreed (processed_data);
@@ -444,6 +504,53 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
444
504
EXPECT_EQ (execute_handle, destroy_handle);
445
505
}
446
506
507
+ /* *
508
+ * Tests that the DataLoader's load is receiving the correct segment info for
509
+ * different types of segments.
510
+ */
511
+ TEST_P (BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
512
+ const void * processed_data = nullptr ;
513
+ StubBackend::singleton ().install_init (
514
+ [&](FreeableBuffer* processed,
515
+ __ET_UNUSED ArrayRef<CompileSpec> compile_specs,
516
+ __ET_UNUSED MemoryAllocator* runtime_allocator)
517
+ -> Result<DelegateHandle*> {
518
+ processed_data = processed->data ();
519
+ processed->Free ();
520
+ return nullptr ;
521
+ });
522
+
523
+ // Wrap the real loader in a spy so we can see which operations were
524
+ // performed.
525
+ Result<FileDataLoader> loader = FileDataLoader::from (program_path ());
526
+ ASSERT_EQ (loader.error (), Error::Ok);
527
+ DataLoaderSpy spy_loader (&loader.get ());
528
+
529
+ // Load the program.
530
+ Result<Program> program = Program::load (&spy_loader);
531
+ ASSERT_EQ (program.error (), Error::Ok);
532
+ ManagedMemoryManager mmm (kDefaultNonConstMemBytes , kDefaultRuntimeMemBytes );
533
+
534
+ // Expect that load was called correctly on program segments.
535
+ bool program_load_was_called =
536
+ spy_loader.UsedLoad (DataLoader::SegmentInfo::Type::Program, nullptr );
537
+
538
+ // Load a method.
539
+ Result<Method> method_res = program->load_method (" forward" , &mmm.get ());
540
+ EXPECT_EQ (method_res.error (), Error::Ok);
541
+
542
+ // Expect that load was called correctly on a backend segment.
543
+ bool backend_load_was_called = spy_loader.UsedLoad (
544
+ DataLoader::SegmentInfo::Type::Backend,
545
+ " backend_segment" ); // TODO(jackzhxng): replace with actual mock PTE
546
+ // file's backend_id in next chained PR.
547
+
548
+ EXPECT_TRUE (program_load_was_called);
549
+ if (using_segments ()) {
550
+ EXPECT_TRUE (backend_load_was_called);
551
+ }
552
+ }
553
+
447
554
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
448
555
// - Errors during init() or execute() result in runtime init/execution failures
449
556
// - Correct values are passed to init()/execute()
0 commit comments