Skip to content

Commit 005429d

Browse files
JacobSzwejbka authored and facebook-github-bot committed
Untie the lifetime of returns and the lifetime of module in pybindings
Summary: In C++ these are tied so as not to force a clone on everyone, due to the performance overhead, and there are ways around it by customizing your own output buffer. In Python, which is really just for testing, it's fine to always clone and avoid unintuitive behavior. Reviewed By: dulinriley Differential Revision: D47933661 fbshipit-source-id: 84178bfd5e9e055078b2d7cfd2c6739df2b3e1ea
1 parent e822e97 commit 005429d

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

extension/pybindings/module.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,13 @@ py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) {
270270
return py::cast(v.toBool());
271271
} else if (Tag::Tensor == v.tag) {
272272
#ifdef USE_ATEN_LIB
273-
return py::cast(v.toTensor());
273+
return py::cast(v.toTensor().clone());
274274
#else
275-
return py::cast(
276-
torch::util::atTensorFromETensor(
277-
v.toTensor().unsafeGetTensorImpl(), keep_alive.sizes),
278-
py::return_value_policy::reference);
275+
// Clone so the outputs in python do not share a lifetime with the module
276+
// object
277+
return py::cast(torch::util::atTensorFromETensor(
278+
v.toTensor().unsafeGetTensorImpl(), keep_alive.sizes)
279+
.clone());
279280
#endif
280281
}
281282
ET_ASSERT_UNREACHABLE();

extension/pybindings/test/test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,14 @@ def test_multiple_entry(self):
157157

158158
executorch_output2 = executorch_module.run_method("forward2", inputs)[0]
159159
self.assertTrue(torch.allclose(executorch_output2, torch.ones(2, 2) * 3))
160+
161+
def test_output_lifespan(self):
162+
def lower_function_call():
163+
program, inputs = create_program(ModuleMulti())
164+
executorch_module = _load_for_executorch_from_buffer(program.buffer)
165+
166+
return executorch_module.forward(inputs)
167+
# executorch_module is destructed here and all of its memory is freed
168+
169+
outputs = lower_function_call()
170+
self.assertTrue(torch.allclose(outputs[0], torch.ones(2, 2) * 2))

0 commit comments

Comments
 (0)