Skip to content

Commit 504366f

Browse files
lucylqfacebook-github-bot
authored andcommitted
Constant segment runtime tests (#1505)
Summary: Pull Request resolved: #1505 - load segment when constants are in segment - no segments when constants are in flatbuffer Reviewed By: dbort Differential Revision: D52434413 fbshipit-source-id: fb54ee074fc5b11815073e0dff295e463ca33044
1 parent 60df682 commit 504366f

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed

runtime/executor/test/method_test.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ class MethodTest : public ::testing::Test {
5555
load_program(std::getenv("ET_MODULE_INDEX_PATH"), "index");
5656
load_program(
5757
std::getenv("ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH"), "cat");
58+
load_program(
59+
std::getenv("ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH"),
60+
"linear_constant_segment");
61+
load_program(
62+
std::getenv("ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"),
63+
"linear_constant_buffer");
5864
}
5965

6066
private:
@@ -196,6 +202,30 @@ TEST_F(MethodTest, AliasedIOTest) {
196202
}
197203
}
198204

205+
TEST_F(MethodTest, ConstantSegmentTest) {
206+
// Execute model with constants stored in segment.
207+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
208+
Result<Method> method =
209+
programs_["linear_constant_segment"]->load_method("forward", &mmm.get());
210+
ASSERT_EQ(method.error(), Error::Ok);
211+
212+
// Can execute the method.
213+
Error err = method->execute();
214+
ASSERT_EQ(err, Error::Ok);
215+
}
216+
217+
TEST_F(MethodTest, ConstantBufferTest) {
218+
// Execute model with constants stored in the program flatbuffer.
219+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
220+
Result<Method> method =
221+
programs_["linear_constant_buffer"]->load_method("forward", &mmm.get());
222+
ASSERT_EQ(method.error(), Error::Ok);
223+
224+
// Can execute the method.
225+
Error err = method->execute();
226+
ASSERT_EQ(err, Error::Ok);
227+
}
228+
199229
// TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
200230
// the portable op lib
201231

runtime/executor/test/program_test.cpp

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <executorch/runtime/core/result.h>
1919
#include <executorch/runtime/executor/program.h>
2020
#include <executorch/runtime/platform/runtime.h>
21+
#include <executorch/schema/program_generated.h>
2122
#include <executorch/test/utils/DeathTest.h>
2223

2324
#include <gtest/gtest.h>
@@ -89,6 +90,11 @@ class ProgramTestFriend final {
8990
size_t index) {
9091
return program->LoadSegment(index);
9192
}
93+
94+
const static executorch_flatbuffer::Program* GetInternalProgram(
95+
const Program* program) {
96+
return program->internal_program_;
97+
}
9298
};
9399
} // namespace testing
94100
} // namespace executor
@@ -299,9 +305,6 @@ TEST_F(ProgramTest, HeaderNotPresent) {
299305
Program::HeaderStatus::NotPresent);
300306
}
301307

302-
// TODO(T144120904): Add tests for programs with segments once we can generate
303-
// them.
304-
305308
TEST_F(ProgramTest, getMethods) {
306309
// Parse the Program from the data.
307310
Result<Program> program_res =
@@ -326,3 +329,70 @@ TEST_F(ProgramTest, DEPRECATEDLoad) {
326329
Result<Program> program_res = Program::Load(multi_loader_.get());
327330
EXPECT_EQ(program_res.error(), Error::Ok);
328331
}
332+
333+
TEST_F(ProgramTest, LoadConstantSegment) {
334+
// Load the serialized ModuleLinear data, with constants in the segment and no
335+
// constants in the flatbuffer.
336+
const char* linear_path =
337+
std::getenv("ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH");
338+
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
339+
ASSERT_EQ(linear_loader.error(), Error::Ok);
340+
341+
// This file should always be compatible.
342+
Result<FreeableBuffer> linear_header =
343+
linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes);
344+
ASSERT_EQ(linear_header.error(), Error::Ok);
345+
EXPECT_EQ(
346+
Program::check_header(linear_header->data(), linear_header->size()),
347+
Program::HeaderStatus::CompatibleVersion);
348+
349+
Result<Program> program = Program::load(&linear_loader.get());
350+
ASSERT_EQ(program.error(), Error::Ok);
351+
352+
// Load constant segment data.
353+
Result<FreeableBuffer> segment =
354+
ProgramTestFriend::LoadSegment(&program.get(), 0);
355+
EXPECT_EQ(segment.error(), Error::Ok);
356+
357+
const executorch_flatbuffer::Program* flatbuffer_program =
358+
ProgramTestFriend::GetInternalProgram(&program.get());
359+
360+
// Expect one segment containing the constants.
361+
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);
362+
363+
// The constant buffer should be empty.
364+
EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);
365+
366+
// Check constant segment offsets.
367+
EXPECT_EQ(flatbuffer_program->constant_segment()->segment_index(), 0);
368+
EXPECT_GE(flatbuffer_program->constant_segment()->offsets()->size(), 1);
369+
}
370+
371+
TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
372+
// Load the serialized ModuleLinear data, with constants in the flatbuffer and
373+
// no constants in the segment.
374+
const char* linear_path =
375+
std::getenv("ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH");
376+
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
377+
ASSERT_EQ(linear_loader.error(), Error::Ok);
378+
379+
// This file should always be compatible.
380+
Result<FreeableBuffer> linear_header =
381+
linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes);
382+
ASSERT_EQ(linear_header.error(), Error::Ok);
383+
EXPECT_EQ(
384+
Program::check_header(linear_header->data(), linear_header->size()),
385+
Program::HeaderStatus::CompatibleVersion);
386+
387+
Result<Program> program = Program::load(&linear_loader.get());
388+
ASSERT_EQ(program.error(), Error::Ok);
389+
390+
const executorch_flatbuffer::Program* flatbuffer_program =
391+
ProgramTestFriend::GetInternalProgram(&program.get());
392+
393+
// Expect no segments.
394+
EXPECT_EQ(flatbuffer_program->segments()->size(), 0);
395+
396+
// The constant buffer should exist.
397+
EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
398+
}

runtime/executor/test/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def define_common_targets(is_fbcode = False):
8585
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
8686
"ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleDynamicCatUnallocatedIO.pte])",
8787
"ET_MODULE_INDEX_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleIndex.pte])",
88+
"ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear-no-constant-segment.pte])",
89+
"ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])",
8890
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
8991
}
9092

@@ -140,6 +142,7 @@ def define_common_targets(is_fbcode = False):
140142
"//executorch/runtime/executor:program",
141143
"//executorch/extension/data_loader:buffer_data_loader",
142144
"//executorch/extension/data_loader:file_data_loader",
145+
"//executorch/schema:program",
143146
],
144147
env = modules_env,
145148
)

0 commit comments

Comments
 (0)