Skip to content

Commit 7153a27

Browse files
JacobSzwejbka authored and facebook-github-bot committed
Support mutable tensors in TensorParser (#4713)
Summary: Pull Request resolved: #4713 Call the new method on program in tensor parser. Add a friend class so it can access it. Reviewed By: lucylq, dvorjackz Differential Revision: D61222257
1 parent caadd81 commit 7153a27

File tree

4 files changed

+102
-8
lines changed

4 files changed

+102
-8
lines changed

runtime/executor/program.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ namespace testing {
3535
class ProgramTestFriend;
3636
} // namespace testing
3737

38+
namespace deserialization {
39+
// Provides Tensor deserialization access to private Program methods.
40+
class TensorParser;
41+
} // namespace deserialization
42+
3843
/**
3944
* A deserialized ExecuTorch program binary.
4045
*/
@@ -194,6 +199,7 @@ class Program final {
194199
friend class BackendDelegate;
195200
friend class Executor;
196201
friend class Method;
202+
friend class deserialization::TensorParser;
197203
friend class testing::ProgramTestFriend;
198204

199205
const executorch_flatbuffer::Program* get_internal_program() const {

runtime/executor/tensor_parser_exec_aten.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,20 @@ namespace torch {
1919
namespace executor {
2020
namespace deserialization {
2121

22+
// Provides access to private Program methods.
23+
class TensorParser final {
24+
public:
25+
__ET_NODISCARD static Error load_mutable_subsegment_into(
26+
const Program* program,
27+
size_t mutable_data_segments_index,
28+
size_t offset_index,
29+
size_t size,
30+
void* buffer) {
31+
return program->load_mutable_subsegment_into(
32+
mutable_data_segments_index, offset_index, size, buffer);
33+
}
34+
};
35+
2236
namespace {
2337

2438
// Retrieve the buffer specified by the allocation_info
@@ -94,14 +108,17 @@ __ET_NODISCARD Result<void*> getTensorDataPtr(
94108

95109
// Memory Planned, with initial state
96110
if (data_buffer_idx > 0 && allocation_info != nullptr) {
97-
// Stub case for now.
98-
99-
// Get memory planned data pointer
100-
101-
// Call something like program.load_into_buffer(s_tensor->segment_idx,
102-
// s_tensor->data_buffer_idx, mem_planned_buffer, nbytes)
111+
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
112+
if (!planned_ptr.ok()) {
113+
return planned_ptr.error();
114+
}
115+
auto err = TensorParser::load_mutable_subsegment_into(
116+
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
103117

104-
return Error::NotImplemented;
118+
if (err != Error::Ok) {
119+
return err;
120+
}
121+
return planned_ptr;
105122

106123
// Constant
107124
} else if (data_buffer_idx > 0 && allocation_info == nullptr) {

runtime/executor/test/tensor_parser_test.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,72 @@ TEST_F(TensorParserTest, TestModuleAddHalf) {
120120
torch::executor::ScalarType::Half,
121121
sizeof(torch::executor::Half));
122122
}
123+
124+
TEST_F(TensorParserTest, TestMutableState) {
125+
// Load the serialized ModuleSimpleTrain data.
126+
const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
127+
Result<FileDataLoader> train_loader = FileDataLoader::from(path);
128+
ASSERT_EQ(train_loader.error(), Error::Ok);
129+
130+
Result<Program> program =
131+
Program::load(&train_loader.get(), Program::Verification::Minimal);
132+
EXPECT_EQ(program.error(), Error::Ok);
133+
134+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
135+
ManagedMemoryManager mmm_copy(
136+
kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
137+
138+
const executorch_flatbuffer::Program* internal_program =
139+
ProgramTestFriend::GetInternalProgram(&program.get());
140+
executorch_flatbuffer::ExecutionPlan* execution_plan =
141+
internal_program->execution_plan()->GetMutableObject(0);
142+
auto flatbuffer_values = execution_plan->values();
143+
144+
size_t num_mutable_tensors = 0;
145+
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
146+
auto serialization_value = flatbuffer_values->Get(i);
147+
if (serialization_value->val_type() ==
148+
executorch_flatbuffer::KernelTypes::Tensor &&
149+
serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
150+
serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
151+
num_mutable_tensors++;
152+
Result<torch::executor::Tensor> tensor = parseTensor(
153+
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
154+
torch::executor::Tensor t = tensor.get();
155+
float loaded_value = t.const_data_ptr<float>()[0];
156+
ASSERT_NE(nullptr, t.const_data_ptr());
157+
ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
158+
t.mutable_data_ptr<float>()[0] = 0.5;
159+
ASSERT_EQ(
160+
t.mutable_data_ptr<float>()[0],
161+
0.5); // 0.5 can be represented perfectly by float so EQ and NE work
162+
// fine here. Any power of 2 rational can be perfectly
163+
// represented. See dyadic rationals for more info.
164+
165+
// Load the same tensor using the same mem manager and show the value is
166+
// updated again.
167+
Result<torch::executor::Tensor> tensor1_alias = parseTensor(
168+
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
169+
torch::executor::Tensor t2 = tensor.get();
170+
ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);
171+
172+
// Show the tensors are equivalent
173+
ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
174+
// Set mutable tensor value back to 0.5 since it got overwritten by second
175+
// parse.
176+
t.mutable_data_ptr<float>()[0] = 0.5;
177+
178+
// Load the same tensor using a different mem manager and show the value
179+
// is not the same as t.
180+
Result<torch::executor::Tensor> tensor_new = parseTensor(
181+
&program.get(),
182+
&mmm_copy.get(),
183+
serialization_value->val_as_Tensor());
184+
torch::executor::Tensor t3 = tensor_new.get();
185+
ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
186+
ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
187+
ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
188+
}
189+
}
190+
ASSERT_EQ(num_mutable_tensors, 2);
191+
}

test/run_oss_cpp_tests.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ build_gtest() {
5252
}
5353

5454
export_test_model() {
55-
python3 -m test.models.export_program --modules "ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry" --outdir "cmake-out" 2> /dev/null
55+
python3 -m test.models.export_program --modules "ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry,ModuleSimpleTrain" --outdir "cmake-out" 2> /dev/null
5656
python3 -m test.models.export_delegated_program --modules "ModuleAddMul" --backend_id "StubBackend" --outdir "cmake-out" || true
5757

5858
ET_MODULE_ADD_HALF_PATH="$(realpath cmake-out/ModuleAddHalf.pte)"
@@ -65,6 +65,7 @@ export_test_model() {
6565
ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH="$(realpath cmake-out/ModuleAddMul-nosegments-da1024.pte)"
6666
ET_MODULE_ADD_MUL_NOSEGMENTS_PATH="$(realpath cmake-out/ModuleAddMul-nosegments.pte)"
6767
ET_MODULE_ADD_MUL_PATH="$(realpath cmake-out/ModuleAddMul.pte)"
68+
ET_MODULE_SIMPLE_TRAIN_PATH="$(realpath cmake-out/ModuleSimpleTrain.pte)"
6869
export ET_MODULE_ADD_HALF_PATH
6970
export ET_MODULE_ADD_PATH
7071
export ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH
@@ -75,6 +76,7 @@ export_test_model() {
7576
export ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH
7677
export ET_MODULE_ADD_MUL_NOSEGMENTS_PATH
7778
export ET_MODULE_ADD_MUL_PATH
79+
export ET_MODULE_SIMPLE_TRAIN_PATH
7880
}
7981

8082
build_and_run_test() {

0 commit comments

Comments (0)