Commit 1c92891

JacobSzwejbka authored and facebook-github-bot committed
Remove strange helpers for python -> tensor conversions, and at -> et tensor conversions (#121)

Summary:
Pull Request resolved: #121

KeepAlive was difficult to manage, and I'm not even sure how sound it was for things like strides and dim order. The run_method loop is now a little grosser, but I think the intent is much clearer: we cast from python types to at types, and then possibly have to do a hard conversion from at::Tensor to ETensor, which can be pretty gross since ETensor doesn't manage any memory. This highlights the need for something like an ETensor wrapper that manages a bunch of state such as sizes and the TensorImpl.

Reviewed By: dbort

Differential Revision: D48618092

fbshipit-source-id: 9cca1f51209f83649f6791af8eba0290ddde0922
1 parent 56f685f commit 1c92891
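
To make the change concrete, here is a minimal sketch of the hard at::Tensor -> ETensor conversion the summary describes, distilled from the new run_method loop in extension/pybindings/module.cpp below. The helper name wrap_contiguous_attensor is hypothetical; the TensorImpl constructor and torch::util bridge calls are the ones used in this diff, and error handling is elided.

#include <vector>

#include <ATen/ATen.h>
#include <executorch/extension/aten_util/aten_bridge.h>

// Hypothetical helper: wrap a contiguous at::Tensor as a
// torch::executor::Tensor. ETensor does not manage any memory, so the
// 32-bit sizes/strides/dim-order buffers and the TensorImpl itself must
// outlive the returned tensor; the caller owns them via these vectors.
torch::executor::Tensor wrap_contiguous_attensor(
    at::Tensor& at_tensor,
    std::vector<torch::executor::Tensor::SizesType>& sizes,
    std::vector<torch::executor::Tensor::StridesType>& strides,
    std::vector<torch::executor::Tensor::DimOrderType>& dim_order,
    std::vector<torch::executor::TensorImpl>& impls) {
  // at::Tensor metadata is int64-based while ETensor's is 32-bit, so the
  // metadata has to be copied rather than aliased.
  sizes.assign(at_tensor.sizes().begin(), at_tensor.sizes().end());
  strides.assign(at_tensor.strides().begin(), at_tensor.strides().end());
  // Identity dim order: only valid for MemoryFormat::Contiguous inputs.
  dim_order.clear();
  for (size_t d = 0; d < static_cast<size_t>(at_tensor.dim()); ++d) {
    dim_order.push_back(d);
  }
  // As in run_method below, the caller must reserve() impls up front so
  // this emplace_back never reallocates and invalidates earlier pointers.
  impls.emplace_back(
      torch::util::torchToExecuTorchScalarType(at_tensor.options().dtype()),
      at_tensor.dim(),
      sizes.data(),
      /*data=*/nullptr,
      dim_order.data(),
      strides.data());
  torch::executor::Tensor etensor(&impls.back());
  // Point the ETensor at the at::Tensor's storage without copying data.
  torch::util::alias_etensor_to_attensor(at_tensor, etensor);
  return etensor;
}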

File tree

3 files changed: +98, -123 lines

extension/aten_util/aten_bridge.cpp

Lines changed: 0 additions & 34 deletions
@@ -147,39 +147,5 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
   check_tensor_meta(t, etensor);
   return t;
 }
-
-std::unique_ptr<torch::executor::TensorImpl> eTensorFromAtTensor(
-    const at::Tensor& tensor,
-    KeepAliveSizes& keep_alive) {
-  auto sizes = tensor.sizes();
-  auto options = tensor.options();
-  keep_alive.sizes32.emplace_back(sizes.size());
-  auto& sizes32 = keep_alive.sizes32.back();
-  for (size_t i = 0; i < sizes.size(); ++i) {
-    // NOLINTNEXTLINE
-    sizes32[i] = sizes[i];
-  }
-
-  const torch::executor::ScalarType edtype =
-      torchToExecuTorchScalarType(options.dtype());
-
-  return std::make_unique<torch::executor::TensorImpl>(
-      edtype, sizes32.size(), sizes32.data(), tensor.mutable_data_ptr());
-}
-
-at::Tensor atTensorFromETensor(
-    torch::executor::TensorImpl* etensor,
-    KeepAliveSizes& keep_alive) {
-  c10::ScalarType dtype = execuTorchtoTorchScalarType(etensor->scalar_type());
-  keep_alive.sizes64.emplace_back(etensor->sizes().size());
-  auto& sizes64 = keep_alive.sizes64.back();
-  for (size_t i = 0; i < etensor->sizes().size(); ++i) {
-    // NOLINTNEXTLINE
-    sizes64[i] = etensor->sizes()[i];
-  }
-  return at::from_blob(
-      etensor->mutable_data(), sizes64, at::TensorOptions(dtype));
-}
-
 } // namespace util
 } // namespace torch

extension/aten_util/aten_bridge.h

Lines changed: 0 additions & 16 deletions
@@ -21,22 +21,6 @@
 namespace torch {
 namespace util {
 
-using sizes32_t = std::vector<int32_t>;
-using sizes64_t = std::vector<int64_t>;
-
-struct KeepAliveSizes {
-  std::vector<sizes32_t> sizes32;
-  std::vector<sizes64_t> sizes64;
-};
-
-// TODO: we should really remove this as
-__ET_DEPRECATED std::unique_ptr<torch::executor::TensorImpl>
-eTensorFromAtTensor(const at::Tensor& tensor, KeepAliveSizes& keep_alive);
-
-__ET_DEPRECATED at::Tensor atTensorFromETensor(
-    torch::executor::TensorImpl* etensor,
-    KeepAliveSizes& keep_alive);
-
 torch::executor::ScalarType torchToExecuTorchScalarType(caffe2::TypeMeta type);
 
 c10::ScalarType execuTorchtoTorchScalarType(torch::executor::ScalarType type);
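
After this change the header keeps only the scalar-type bridge. A minimal usage sketch, assuming just the two declarations above (the include path mirrors this file's location in the tree):

#include <ATen/ATen.h>
#include <executorch/extension/aten_util/aten_bridge.h>

int main() {
  at::Tensor t = at::ones({2, 3}, at::kFloat);
  // c10/at dtype -> ExecuTorch dtype...
  torch::executor::ScalarType et_dtype =
      torch::util::torchToExecuTorchScalarType(t.options().dtype());
  // ...and back; the round trip should be the identity.
  c10::ScalarType at_dtype = torch::util::execuTorchtoTorchScalarType(et_dtype);
  return at_dtype == c10::ScalarType::Float ? 0 : 1;
}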

extension/pybindings/module.cpp

Lines changed: 98 additions & 73 deletions
@@ -55,7 +55,6 @@
 })
 
 namespace py = pybind11;
-using ATTensor = at::Tensor;
 namespace torch {
 namespace executor {
 
@@ -134,15 +133,7 @@ class Module final {
 
   /// Executes the specified method on the provided inputs and returns its
   /// outputs.
-  template <typename... Types>
   std::vector<EValue> run_method(
-      const std::string& method_name,
-      Types&&... args) {
-    return run_method_internal(method_name, std::vector<EValue>{args...});
-  }
-
- private:
-  std::vector<EValue> run_method_internal(
       const std::string& method_name,
       const std::vector<EValue>& args) {
     auto& method = methods_[method_name];
@@ -187,6 +178,7 @@ class Module final {
     return result;
   }
 
+ private:
  /// A wrapper/util class for executorch memory allocations/manager.
  class Memory {
   public:
@@ -266,66 +258,6 @@ inline std::unique_ptr<Module> load_from_file(const std::string& path) {
   return std::make_unique<Module>(std::move(loader));
 }
 
-// Struct used to manage the memory of tensors allocated in lean (not aten) mode
-#ifdef USE_ATEN_LIB
-struct KeepAlive {};
-#else
-struct KeepAlive {
-  std::vector<std::unique_ptr<exec_aten::TensorImpl>> tensors;
-  torch::util::KeepAliveSizes sizes;
-};
-#endif
-
-EValue pyToEValue(py::handle h, KeepAlive& keep_alive) {
-  const std::string& type_str = py::str(h.get_type());
-  EXECUTORCH_SCOPE_PROF("pyToEValue");
-  if (type_str == "<class 'torch.Tensor'>") {
-    auto atTensor = h.cast<ATTensor>();
-#ifdef USE_ATEN_LIB
-    EValue evalue(atTensor);
-#else
-    auto etensorImpl =
-        torch::util::eTensorFromAtTensor(atTensor, keep_alive.sizes);
-    EValue evalue(torch::executor::Tensor(etensorImpl.get()));
-    keep_alive.tensors.push_back(std::move(etensorImpl));
-#endif
-    return evalue;
-  } else if (py::isinstance<py::none>(h)) {
-    return EValue();
-  } else if (py::isinstance<py::bool_>(h)) {
-    return EValue(py::cast<bool>(h));
-  } else if (py::isinstance<py::int_>(h)) {
-    return EValue(py::cast<int64_t>(h));
-  } else {
-    // Unsupported pytype
-    ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
-  }
-}
-
-py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) {
-  EXECUTORCH_SCOPE_PROF("pyFromEValue");
-  if (Tag::None == v.tag) {
-    return py::none();
-  } else if (Tag::Int == v.tag) {
-    return py::cast(v.toInt());
-  } else if (Tag::Double == v.tag) {
-    return py::cast(v.toDouble());
-  } else if (Tag::Bool == v.tag) {
-    return py::cast(v.toBool());
-  } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-    return py::cast(v.toTensor().clone());
-#else
-    // Clone so the outputs in python do not share a lifetime with the module
-    // object
-    return py::cast(torch::util::atTensorFromETensor(
-                        v.toTensor().unsafeGetTensorImpl(), keep_alive.sizes)
-                        .clone());
-#endif
-  }
-  ET_ASSERT_UNREACHABLE();
-}
-
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
 struct PyBundledModule final {
@@ -406,19 +338,113 @@ struct PyModule final {
   py::list run_method(
       const std::string& method_name,
       const py::sequence& inputs) {
-    std::vector<EValue> cpp_inputs;
     const auto inputs_size = py::len(inputs);
+    std::vector<EValue> cpp_inputs;
     cpp_inputs.reserve(inputs_size);
+
+#ifndef USE_ATEN_LIB // Portable mode
+    // So the ETensors and their metadata stay in scope for Module->run_method.
+    std::vector<torch::executor::TensorImpl> input_tensors;
+    std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
+    std::vector<std::vector<torch::executor::Tensor::StridesType>>
+        input_strides;
+    std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
+        input_dim_order;
+    // We store pointers to these vector elements, so it is important to
+    // reserve so that we don't lose them on a vector resize. We don't need
+    // to do this for the others since they are vectors of vectors, and we
+    // don't store a pointer to the root-level vector data.
+    input_tensors.reserve(inputs_size);
+#endif
+
+    // Convert python objects into EValues.
     for (size_t i = 0; i < inputs_size; ++i) {
-      cpp_inputs.emplace_back(pyToEValue(inputs[i], keep_alive_));
+      auto python_input = inputs[i];
+      const std::string& type_str = py::str(python_input.get_type());
+      if (type_str == "<class 'torch.Tensor'>") {
+        auto at_tensor = python_input.cast<at::Tensor>();
+        // alias_etensor_to_attensor will assert on this later, so to better
+        // propagate up to python we check early and throw an exception.
+        if (!at_tensor.is_contiguous()) {
+          auto error_msg = "Input " + std::to_string(i) + " for method " +
+              method_name + " is not contiguous.";
+          throw std::runtime_error(error_msg);
+        }
+
+#ifdef USE_ATEN_LIB
+        EValue evalue(at_tensor);
+#else
+        // Convert at::Tensor to torch::executor::Tensor.
+        auto type = torch::util::torchToExecuTorchScalarType(
+            at_tensor.options().dtype());
+        size_t dim = at_tensor.dim();
+        // Can't directly alias at::Tensor sizes and strides due to the
+        // int64 vs int32 typing conflict.
+        input_sizes.emplace_back(
+            at_tensor.sizes().begin(), at_tensor.sizes().end());
+        input_strides.emplace_back(
+            at_tensor.strides().begin(), at_tensor.strides().end());
+
+        // Only works for MemoryFormat::Contiguous inputs.
+        std::vector<torch::executor::Tensor::DimOrderType> dim_order;
+        for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
+          dim_order.push_back(cur_dim);
+        }
+        input_dim_order.push_back(std::move(dim_order));
+        input_tensors.emplace_back(
+            type,
+            dim,
+            input_sizes[i].data(),
+            nullptr,
+            input_dim_order[i].data(),
+            input_strides[i].data());
+
+        torch::executor::Tensor temp =
+            torch::executor::Tensor(&input_tensors[i]);
+        torch::util::alias_etensor_to_attensor(at_tensor, temp);
+        EValue evalue(temp);
+#endif
+
+        cpp_inputs.push_back(evalue);
+      } else if (py::isinstance<py::none>(python_input)) {
+        cpp_inputs.push_back(EValue());
+      } else if (py::isinstance<py::bool_>(python_input)) {
+        cpp_inputs.push_back(EValue(py::cast<bool>(python_input)));
+      } else if (py::isinstance<py::int_>(python_input)) {
+        cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
+      } else {
+        // Unsupported pytype
+        ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
+      }
     }
 
     auto outputs = module_->run_method(method_name, cpp_inputs);
 
+    // Retrieve outputs.
     const auto outputs_size = outputs.size();
     py::list list(outputs_size);
     for (size_t i = 0; i < outputs_size; ++i) {
-      list[i] = pyFromEValue(outputs[i], keep_alive_);
+      auto& v = outputs[i];
+      if (Tag::None == v.tag) {
+        list[i] = py::none();
+      } else if (Tag::Int == v.tag) {
+        list[i] = py::cast(v.toInt());
+      } else if (Tag::Double == v.tag) {
+        list[i] = py::cast(v.toDouble());
+      } else if (Tag::Bool == v.tag) {
+        list[i] = py::cast(v.toBool());
      } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+        // Clone so the outputs in python do not share a lifetime with the
+        // module object.
+        list[i] = py::cast(v.toTensor().clone());
+#else
+        list[i] = py::cast(
+            torch::util::alias_attensor_to_etensor(v.toTensor()).clone());
+#endif
+      } else {
+        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+      }
     }
     return list;
   }
@@ -428,7 +454,6 @@ struct PyModule final {
   }
 
  private:
-  KeepAlive keep_alive_;
   std::unique_ptr<Module> module_;
 };
 
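The summary's closing point is that this conversion dance calls for an ETensor wrapper that owns its metadata and TensorImpl. A hypothetical sketch of what such a wrapper could look like; the name ManagedETensor and its exact shape are illustrative, not part of this commit (portable Tensor/TensorImpl includes elided, since the header layout is not shown in this diff):

#include <utility>
#include <vector>

// Owns the state a raw ETensor merely points at, so callers no longer
// juggle parallel keep-alive vectors the way run_method does above.
class ManagedETensor {
 public:
  ManagedETensor(
      torch::executor::ScalarType dtype,
      std::vector<torch::executor::Tensor::SizesType> sizes,
      std::vector<torch::executor::Tensor::DimOrderType> dim_order,
      std::vector<torch::executor::Tensor::StridesType> strides,
      void* data)
      : sizes_(std::move(sizes)),
        dim_order_(std::move(dim_order)),
        strides_(std::move(strides)),
        // Same TensorImpl constructor as used in run_method above.
        impl_(
            dtype,
            sizes_.size(),
            sizes_.data(),
            data,
            dim_order_.data(),
            strides_.data()) {}

  // The returned Tensor stays valid for as long as this wrapper lives.
  torch::executor::Tensor tensor() {
    return torch::executor::Tensor(&impl_);
  }

 private:
  // Members are declared (and therefore initialized) before impl_, which
  // stores pointers into them.
  std::vector<torch::executor::Tensor::SizesType> sizes_;
  std::vector<torch::executor::Tensor::DimOrderType> dim_order_;
  std::vector<torch::executor::Tensor::StridesType> strides_;
  torch::executor::TensorImpl impl_;
};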
