@@ -62,7 +62,7 @@ class ProgramTest : public ::testing::Test {
62
62
63
63
add_loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
64
64
65
- // Load the serialized ModuleAdd data.
65
+ // Load the serialized ModuleMultiEntry data.
66
66
path = std::getenv (" ET_MODULE_MULTI_ENTRY_PATH" );
67
67
Result<FileDataLoader> multi_loader = FileDataLoader::from (path);
68
68
ASSERT_EQ (multi_loader.error (), Error::Ok);
@@ -98,6 +98,16 @@ class ProgramTestFriend final {
98
98
return program->LoadSegment (segment_info);
99
99
}
100
100
101
+ __ET_NODISCARD static Error load_mutable_subsegment_into (
102
+ const Program* program,
103
+ size_t mutable_data_segments_index,
104
+ size_t offset_index,
105
+ size_t size,
106
+ void * buffer) {
107
+ return program->load_mutable_subsegment_into (
108
+ mutable_data_segments_index, offset_index, size, buffer);
109
+ }
110
+
101
111
const static executorch_flatbuffer::Program* GetInternalProgram (
102
112
const Program* program) {
103
113
return program->internal_program_ ;
@@ -444,3 +454,86 @@ TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
444
454
// The constant buffer should exist.
445
455
EXPECT_GE (flatbuffer_program->constant_buffer ()->size (), 1 );
446
456
}
457
+
458
+ TEST_F (ProgramTest, LoadFromMutableSegment) {
459
+ // Load the serialized ModuleSimpleTrain data.
460
+ auto path = std::getenv (" ET_MODULE_SIMPLE_TRAIN_PATH" );
461
+ Result<FileDataLoader> training_loader = FileDataLoader::from (path);
462
+ ASSERT_EQ (training_loader.error (), Error::Ok);
463
+
464
+ // This file should always be compatible.
465
+ Result<FreeableBuffer> training_header = training_loader->load (
466
+ /* offset=*/ 0 ,
467
+ Program::kMinHeadBytes ,
468
+ /* segment_info=*/
469
+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::Program));
470
+ ASSERT_EQ (training_header.error (), Error::Ok);
471
+ EXPECT_EQ (
472
+ Program::check_header (training_header->data (), training_header->size ()),
473
+ Program::HeaderStatus::CompatibleVersion);
474
+
475
+ Result<Program> program = Program::load (&training_loader.get ());
476
+ ASSERT_EQ (program.error (), Error::Ok);
477
+
478
+ // dummy buffers to load into
479
+ uint8_t buffer[1 ] = {0 };
480
+ uint8_t buffer2[1 ] = {0 };
481
+
482
+ // Load some mutable segment data
483
+ Error err = ProgramTestFriend::load_mutable_subsegment_into (
484
+ &program.get (), 0 , 1 , 1 , buffer);
485
+ EXPECT_EQ (err, Error::Ok);
486
+
487
+ // Check that the data loaded correctly, and then mutate it
488
+ EXPECT_EQ (buffer[0 ], 232 ); // 232 comes from inspecting the file itself. The
489
+ // file is seeded so this value should be stable.
490
+ buffer[0 ] = 0 ;
491
+
492
+ // Load some more mutable segment data
493
+ err = ProgramTestFriend::load_mutable_subsegment_into (
494
+ &program.get (), 0 , 1 , 1 , buffer2);
495
+ EXPECT_EQ (err, Error::Ok);
496
+
497
+ // Check that new data loaded from the file does not reflect the change to
498
+ // buffer.
499
+ EXPECT_EQ (buffer2[0 ], 232 );
500
+
501
+ const executorch_flatbuffer::Program* flatbuffer_program =
502
+ ProgramTestFriend::GetInternalProgram (&program.get ());
503
+
504
+ // Expect 1 segment. 1 mutable segment and no constant segment.
505
+ EXPECT_EQ (flatbuffer_program->segments ()->size (), 1 );
506
+
507
+ // Expect a mutable data segment.
508
+ EXPECT_EQ (flatbuffer_program->mutable_data_segments ()->size (), 1 );
509
+
510
+ // Expect the 0 index to be reserved and the offsets for weight and bias of
511
+ // linear to be indices 1 and 2.
512
+ EXPECT_EQ (
513
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->size (),
514
+ 3 );
515
+ EXPECT_EQ (
516
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (0 ),
517
+ 0 );
518
+ EXPECT_EQ (
519
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (1 ),
520
+ 0 );
521
+ EXPECT_EQ (
522
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (2 ),
523
+ 36 );
524
+
525
+ // Loading beyond file should fail
526
+ err = ProgramTestFriend::load_mutable_subsegment_into (
527
+ &program.get (), 0 , 1 , 500 , buffer);
528
+ EXPECT_NE (err, Error::Ok);
529
+
530
+ // Loading beyond offsets should fail
531
+ err = ProgramTestFriend::load_mutable_subsegment_into (
532
+ &program.get (), 0 , 500 , 1 , buffer);
533
+ EXPECT_NE (err, Error::Ok);
534
+
535
+ // Loading beyond segments should fail
536
+ err = ProgramTestFriend::load_mutable_subsegment_into (
537
+ &program.get (), 500 , 1 , 1 , buffer);
538
+ EXPECT_NE (err, Error::Ok);
539
+ }
0 commit comments