
Commit df72b8c

JacobSzwejbka authored and facebook-github-bot committed
Use TensorMeta to check if inputs and outputs are memory planned (#5565)
Summary:
Pull Request resolved: #5565

Swap to using method meta so we can be finer grained about this check: instead of one pre-allocation flag per method, each input and output tensor now reports whether it is memory planned.

Reviewed By: dbort

Differential Revision: D62983475

fbshipit-source-id: c4599c5ecad0409cd8b2670464c4e9e8809b49ad
1 parent f4728f4 commit df72b8c
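
Previously, Method tracked a single pre_allocated_input_ / pre_allocated_output_ boolean for the whole method; after this change each input and output tensor reports its own status through MethodMeta and TensorInfo::is_memory_planned(), with constant-backed tensors counted as planned. A minimal caller-side sketch of the resulting contract (an illustration, not code from this commit; `method`, `my_buffer`, and `my_buffer_size` are hypothetical, and the executorch::runtime namespace spelling of this era is assumed):

#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>

using executorch::runtime::Error;
using executorch::runtime::Method;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// Attach a caller-owned buffer to output 0 only if the runtime does not
// already own that output's storage. With this commit, calling
// set_output_data_ptr() on a memory-planned (or constant) output fails with
// Error::InvalidState, so check the per-tensor metadata first.
Error maybe_attach_output_buffer(
    Method& method, void* my_buffer, size_t my_buffer_size) {
  Result<TensorInfo> out_meta =
      method.method_meta().output_tensor_meta(/*index=*/0);
  if (!out_meta.ok()) {
    return out_meta.error();
  }
  if (out_meta->is_memory_planned()) {
    // Storage was planned in (or is constant); nothing for the caller to do.
    return Error::Ok;
  }
  return method.set_output_data_ptr(my_buffer, my_buffer_size, /*output_idx=*/0);
}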

4 files changed: 75 additions & 64 deletions

4 files changed

+75
-64
lines changed

extension/pybindings/test/make_test.py

Lines changed: 50 additions & 3 deletions
@@ -7,10 +7,11 @@
 # pyre-unsafe

 import unittest
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Optional, Tuple

 import torch
-from executorch.exir import ExecutorchProgramManager, to_edge
+from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
+from executorch.exir.passes import MemoryPlanningPass
 from torch.export import export


@@ -75,8 +76,25 @@ def get_methods_to_export(self):
     def get_inputs(self):
         return (torch.ones(2, 2),)

+class ModuleAddConstReturn(torch.nn.Module):
+    """The module to serialize and execute."""
+
+    def __init__(self):
+        super(ModuleAddConstReturn, self).__init__()
+        self.state = torch.ones(2, 2)
+
+    def forward(self, x):
+        return x + self.state, self.state
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+    def get_inputs(self):
+        return (torch.ones(2, 2),)
+
 def create_program(
     eager_module: torch.nn.Module,
+    et_config: Optional[ExecutorchBackendConfig] = None,
 ) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
     """Returns an executorch program based on ModuleAdd, along with inputs."""

@@ -103,7 +121,7 @@ def forward(self, *args, **kwargs):
         )
         exported_methods[method_name] = export(wrapped_mod, method_input)

-    exec_prog = to_edge(exported_methods).to_executorch()
+    exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

     # Create the ExecuTorch program from the graph.
     exec_prog.dump_executorch_program(verbose=True)

@@ -251,12 +269,41 @@ def test_quantized_ops(tester):
             expected = example_inputs[0] + example_inputs[1]
             tester.assertEqual(str(expected), str(executorch_output))

+        def test_constant_output_not_memory_planned(tester):
+            # Create an ExecuTorch program from ModuleAddConstReturn.
+            exported_program, inputs = create_program(
+                ModuleAddConstReturn(),
+                et_config=ExecutorchBackendConfig(
+                    memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
+                ),
+            )
+
+            exported_program.dump_executorch_program(verbose=True)
+
+            # Use pybindings to load and execute the program.
+            executorch_module = load_fn(exported_program.buffer)
+            # Invoke the callable on executorch_module instead of calling module.forward.
+            # Use only one input to test this case.
+            executorch_output = executorch_module((torch.ones(2, 2),))
+            print(executorch_output)
+
+            # The test module adds the input to torch.ones(2, 2), so its output
+            # should be the same as adding them directly.
+            expected = torch.ones(2, 2) + torch.ones(2, 2)
+            tester.assertEqual(str(expected), str(executorch_output[0]))
+
+            # The test module returns the state. Check that its value is correct.
+            tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
+
+        ######### RUN TEST CASES #########
+
         test_e2e(tester)
         test_multiple_entry(tester)
         test_output_lifespan(tester)
         test_module_callable(tester)
         test_module_single_input(tester)
         test_stderr_redirect(tester)
         test_quantized_ops(tester)
+        test_constant_output_not_memory_planned(tester)

     return wrapper

runtime/executor/method.cpp

Lines changed: 14 additions & 47 deletions
@@ -744,40 +744,6 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
     }
   }

-  // Validate input values and get tensor pre-allocation info.
-  pre_allocated_input_ = false;
-  for (int i = 0; i < inputs_size(); i++) {
-    // get_input() will panic if the index is invalid, so do this manually.
-    size_t index = get_input_index(i);
-    ET_CHECK_OR_RETURN_ERROR(
-        index < n_value_,
-        InvalidProgram,
-        "Input index %zu >= %zu",
-        index,
-        n_value_);
-    const EValue& input = values_[index];
-    if (input.isTensor()) {
-      pre_allocated_input_ |= input.toTensor().const_data_ptr() != nullptr;
-    }
-  }
-
-  // Validate output values and get tensor pre-allocation info.
-  pre_allocated_output_ = false;
-  for (int i = 0; i < outputs_size(); i++) {
-    // get_output() will panic if the index is invalid, so do this manually.
-    size_t index = get_output_index(i);
-    ET_CHECK_OR_RETURN_ERROR(
-        index < n_value_,
-        InvalidProgram,
-        "output index %zu >= %zu",
-        index,
-        n_value_);
-    const EValue& output = values_[index];
-    if (output.isTensor()) {
-      pre_allocated_output_ |= output.toTensor().const_data_ptr() != nullptr;
-    }
-  }
-
   step_state_ = StepState{0, 0};

   init_state_ = InitializationState::Initialized;

@@ -841,7 +807,8 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
       input_idx,
       static_cast<uint32_t>(err));
   Error error;
-  if (pre_allocated_input_) {
+  auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
+  if (tensor_meta->is_memory_planned()) {
     error = internal::copy_tensor_data(t_dst, t_src);
   } else {
     error = internal::share_tensor_data(t_dst, t_src);

@@ -950,21 +917,11 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
       InvalidState,
       "Outputs can not be retrieved until method has been initialized.");

-  // ET_CHECK_OR_RETURN_ERROR(
-  //     !pre_allocated_output_,
-  //     InvalidState,
-  //     "Overriding output data pointer allocated by memory plan is not
-  //     allowed.");
-  // TODO(T188740925): for now, return error without logs.
-  if (pre_allocated_output_) {
-    return Error::InvalidState;
-  }
-
   // Check the args
   ET_CHECK_OR_RETURN_ERROR(
-      output_idx <= outputs_size(),
+      output_idx < outputs_size(),
       InvalidArgument,
-      "output_idx: %zu num_outputs: %zu",
+      "output_idx: %zu > num_outputs: %zu",
       output_idx,
       outputs_size());

@@ -975,6 +932,16 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
       "output type: %zu is not tensor",
       (size_t)output.tag);

+  auto tensor_meta = this->method_meta().output_tensor_meta(output_idx);
+  if (tensor_meta->is_memory_planned()) {
+    ET_LOG(
+        Error,
+        "Output %zu is memory planned, or is a constant. Cannot override \
+        the existing data pointer.",
+        output_idx);
+    return Error::InvalidState;
+  }
+
   auto& t = output.toTensor();
   ET_CHECK_OR_RETURN_ERROR(
       output.isTensor(),
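
The set_input() hunk above keeps the long-standing copy-versus-share split but keys it off the per-tensor metadata. A hedged sketch of the caller-visible consequence (illustrative only; `method` and `input_tensor` are hypothetical):

// Memory-planned input: the runtime copies the data into the planned arena,
// so the caller may reuse input_tensor's storage after set_input() returns.
// Unplanned input: the runtime shares the caller's data pointer, so
// input_tensor must stay alive (and unchanged) until execution finishes.
executorch::runtime::EValue in_value(input_tensor);
executorch::runtime::Error err = method.set_input(in_value, /*input_idx=*/0);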

runtime/executor/method.h

Lines changed: 2 additions & 10 deletions
@@ -62,9 +62,7 @@ class Method final {
         delegates_(rhs.delegates_),
         n_chains_(rhs.n_chains_),
         chains_(rhs.chains_),
-        init_state_(rhs.init_state_),
-        pre_allocated_input_(rhs.pre_allocated_input_),
-        pre_allocated_output_(rhs.pre_allocated_output_) {
+        init_state_(rhs.init_state_) {
     // Required: clear out fields that the dtor looks at, so that we don't free
     // anything twice.
     rhs.n_value_ = 0;

@@ -82,8 +80,6 @@
     rhs.event_tracer_ = nullptr;
     rhs.n_chains_ = 0;
     rhs.chains_ = nullptr;
-    rhs.pre_allocated_input_ = false;
-    rhs.pre_allocated_output_ = false;
   }

   /**

@@ -288,9 +284,7 @@
         delegates_(nullptr),
         n_chains_(0),
         chains_(nullptr),
-        init_state_(InitializationState::Uninitialized),
-        pre_allocated_input_(false),
-        pre_allocated_output_(false) {}
+        init_state_(InitializationState::Uninitialized) {}

   /// Static factory used by Program.
   ET_NODISCARD static Result<Method> load(

@@ -336,8 +330,6 @@
   Chain* chains_;

   InitializationState init_state_;
-  bool pre_allocated_input_;
-  bool pre_allocated_output_;

   /**
    * Parses the elements of the values_ array. On error, n_value_ will be set to

runtime/executor/method_meta.cpp

Lines changed: 9 additions & 4 deletions
@@ -139,7 +139,9 @@ Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
       Span<const uint8_t>(
           tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
       static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
-      tensor_value->allocation_info() != nullptr);
+      tensor_value->allocation_info() != nullptr ||
+          tensor_value->data_buffer_idx() !=
+              0); // Count constant returns as memory planned.
 }

 size_t MethodMeta::num_outputs() const {

@@ -170,15 +172,18 @@ Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
       "Tag: %zu output: %zu is not Tensor",
       (size_t)tag.get(),
       index);
-  auto input_index = s_plan_->outputs()->Get(index);
-  auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
+  auto output_index = s_plan_->outputs()->Get(index);
+  auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();
+
   return TensorInfo(
       Span<const int32_t>(
           tensor_value->sizes()->data(), tensor_value->sizes()->size()),
       Span<const uint8_t>(
           tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
       static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
-      tensor_value->allocation_info() != nullptr);
+      tensor_value->allocation_info() != nullptr ||
+          tensor_value->data_buffer_idx() !=
+              0); // Count constant returns as memory planned.
 }

 size_t MethodMeta::num_memory_planned_buffers() const {
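
Both hunks apply the same classification rule, which can be read as a single predicate over the serialized tensor (an illustrative sketch; the helper name is hypothetical, while the field accessors follow the executorch flatbuffer schema used in the diff):

// A tensor counts as memory planned when the plan reserved space for it
// (allocation_info is present) or when it is backed by constant data
// (data_buffer_idx != 0; buffer index 0 is the reserved empty entry, so a
// nonzero index marks a constant).
static bool is_planned_or_constant(const executorch_flatbuffer::Tensor* t) {
  return t->allocation_info() != nullptr || t->data_buffer_idx() != 0;
}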
