Skip to content

Commit 610f333

Browse files
JacobSzwejbka authored and facebook-github-bot committed
Support mutable tensors in TensorParser (#4713)
Summary: Pull Request resolved: #4713 Call the new method on program in tensor parser. Add a friend class so it can access it. Reviewed By: dvorjackz Differential Revision: D61222257
1 parent 6efc222 commit 610f333

File tree

5 files changed

+112
-9
lines changed

5 files changed

+112
-9
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ if(EXECUTORCH_BUILD_FLATC)
428428
# exir lets users set the alignment of tensor data embedded in the flatbuffer,
429429
# and some users need an alignment larger than the default, which is typically
430430
# 32.
431-
target_compile_definitions(flatc PRIVATE FLATBUFFERS_MAX_ALIGNMENT=1024)
431+
target_compile_definitions(flatc PRIVATE FLATBUFFERS_MAX_ALIGNMENT=2048)
432432
endif()
433433
if(NOT FLATC_EXECUTABLE)
434434
message(

runtime/executor/program.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ namespace testing {
3535
class ProgramTestFriend;
3636
} // namespace testing
3737

38+
namespace deserialization {
39+
// Provides Tensor deserializaiton access to private Program methods.
40+
class TensorParser;
41+
} // namespace deserialization
42+
3843
/**
3944
* A deserialized ExecuTorch program binary.
4045
*/
@@ -194,6 +199,7 @@ class Program final {
194199
friend class BackendDelegate;
195200
friend class Executor;
196201
friend class Method;
202+
friend class deserialization::TensorParser;
197203
friend class testing::ProgramTestFriend;
198204

199205
const executorch_flatbuffer::Program* get_internal_program() const {

runtime/executor/tensor_parser_exec_aten.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,20 @@ namespace torch {
1919
namespace executor {
2020
namespace deserialization {
2121

22+
// Provides access to private Program methods.
23+
class TensorParser final {
24+
public:
25+
__ET_NODISCARD static Error load_mutable_subsegment_into(
26+
const Program* program,
27+
size_t mutable_data_segments_index,
28+
size_t offset_index,
29+
size_t size,
30+
void* buffer) {
31+
return program->load_mutable_subsegment_into(
32+
mutable_data_segments_index, offset_index, size, buffer);
33+
}
34+
};
35+
2236
namespace {
2337

2438
// Retrieve the buffer specified by the allocation_info
@@ -94,14 +108,17 @@ __ET_NODISCARD Result<void*> getTensorDataPtr(
94108

95109
// Memory Planned, with initial state
96110
if (data_buffer_idx > 0 && allocation_info != nullptr) {
97-
// Stub case for now.
98-
99-
// Get memory planned data pointer
100-
101-
// Call something like program.load_into_buffer(s_tensor->segment_idx,
102-
// s_tensor->data_buffer_idx, mem_planned_buffer, nbytes)
111+
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
112+
if (!planned_ptr.ok()) {
113+
return planned_ptr.error();
114+
}
115+
auto err = TensorParser::load_mutable_subsegment_into(
116+
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
103117

104-
return Error::NotImplemented;
118+
if (err != Error::Ok) {
119+
return err;
120+
}
121+
return planned_ptr;
105122

106123
// Constant
107124
} else if (data_buffer_idx > 0 && allocation_info == nullptr) {

runtime/executor/test/tensor_parser_test.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,83 @@ TEST_F(TensorParserTest, TestModuleAddHalf) {
120120
torch::executor::ScalarType::Half,
121121
sizeof(torch::executor::Half));
122122
}
123+
124+
TEST_F(TensorParserTest, TestMutableState) {
125+
// Load the serialized ModuleSimpleTrain data.
126+
const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
127+
Result<FileDataLoader> train_loader = FileDataLoader::from(path);
128+
ASSERT_EQ(train_loader.error(), Error::Ok);
129+
130+
Result<Program> program =
131+
Program::load(&train_loader.get(), Program::Verification::Minimal);
132+
EXPECT_EQ(program.error(), Error::Ok);
133+
134+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
135+
ManagedMemoryManager mmm_copy(
136+
kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
137+
138+
const executorch_flatbuffer::Program* internal_program =
139+
ProgramTestFriend::GetInternalProgram(&program.get());
140+
executorch_flatbuffer::ExecutionPlan* execution_plan =
141+
internal_program->execution_plan()->GetMutableObject(0);
142+
auto flatbuffer_values = execution_plan->values();
143+
144+
size_t num_mutable_tensors = 0;
145+
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
146+
auto serialization_value = flatbuffer_values->Get(i);
147+
if (serialization_value->val_type() ==
148+
executorch_flatbuffer::KernelTypes::Tensor &&
149+
serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
150+
serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
151+
num_mutable_tensors++;
152+
Result<torch::executor::Tensor> tensor = parseTensor(
153+
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
154+
torch::executor::Tensor t = tensor.get();
155+
float loaded_value = t.const_data_ptr<float>()[0];
156+
ASSERT_NE(nullptr, t.const_data_ptr());
157+
ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
158+
t.mutable_data_ptr<float>()[0] = 0.5;
159+
ASSERT_EQ(
160+
t.mutable_data_ptr<float>()[0],
161+
0.5); // 0.5 can be represented perfectly by float so EQ and NE work
162+
// fine here. Any power of 2 rational can be perfectly
163+
// represented. See dyadic rationals for more info.
164+
165+
// Load the same tensor using the same mem manager and show the value is
166+
// updated again.
167+
Result<torch::executor::Tensor> tensor1_alias = parseTensor(
168+
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
169+
torch::executor::Tensor t2 = tensor.get();
170+
ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);
171+
172+
// Show the tensors are equivalent
173+
ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
174+
t.mutable_data_ptr<float>()[0] = 0.5;
175+
176+
// Load the same tensor using a different mem manager and show the value
177+
// is not the same as t1.
178+
Result<torch::executor::Tensor> tensor_new = parseTensor(
179+
&program.get(),
180+
&mmm_copy.get(),
181+
serialization_value->val_as_Tensor());
182+
torch::executor::Tensor t3 = tensor_new.get();
183+
ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
184+
ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
185+
ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
186+
187+
// Hard check the first byte of the serialized data.
188+
// 232 and 210 comes from inspecting the file itself. The
189+
// file is seeded so this value should be stable.
190+
if (num_mutable_tensors == 1) {
191+
const uint8_t* byte_data =
192+
reinterpret_cast<const uint8_t*>(t3.const_data_ptr());
193+
ASSERT_EQ(byte_data[0], 232);
194+
} else if (num_mutable_tensors == 2) {
195+
const uint8_t* byte_data =
196+
reinterpret_cast<const uint8_t*>(t3.const_data_ptr());
197+
ASSERT_EQ(byte_data[0], 210);
198+
}
199+
}
200+
}
201+
ASSERT_EQ(num_mutable_tensors, 2);
202+
}

schema/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function(generate_program_schema _schema_srcs _schema_name)
4949
# and some users need an alignment larger than the default, which is typically
5050
# 32.
5151
target_compile_definitions(
52-
${_schema_name} INTERFACE FLATBUFFERS_MAX_ALIGNMENT=1024)
52+
${_schema_name} INTERFACE FLATBUFFERS_MAX_ALIGNMENT=2048)
5353

5454
target_include_directories(
5555
${_schema_name}

0 commit comments

Comments
 (0)