})

namespace py = pybind11;
- using ATTensor = at::Tensor;
namespace torch {
namespace executor {

@@ -134,15 +133,7 @@ class Module final {

  /// Executes the specified method on the provided inputs and returns its
  /// outputs.
- template <typename... Types>
  std::vector<EValue> run_method(
-     const std::string& method_name,
-     Types&&... args) {
-   return run_method_internal(method_name, std::vector<EValue>{args...});
- }
-
- private:
- std::vector<EValue> run_method_internal(
      const std::string& method_name,
      const std::vector<EValue>& args) {
    auto& method = methods_[method_name];
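Review note: with the variadic overload and the private run_method_internal gone, callers now build the std::vector<EValue> themselves and go through the single public signature above. A minimal sketch under that assumption (the module handle and the "forward" method name are hypothetical, for illustration only):

    // Hypothetical caller of the collapsed run_method signature.
    std::vector<EValue> args;
    args.emplace_back(static_cast<int64_t>(1));  // EValue from a scalar input
    std::vector<EValue> outputs = module->run_method("forward", args);
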
@@ -187,6 +178,7 @@ class Module final {
    return result;
  }

+ private:
  /// A wrapper/util class for executorch memory allocations/manager.
  class Memory {
   public:
@@ -266,66 +258,6 @@ inline std::unique_ptr<Module> load_from_file(const std::string& path) {
  return std::make_unique<Module>(std::move(loader));
}

- // Struct used to manage the memory of tensors allocated in lean (not aten) mode
- #ifdef USE_ATEN_LIB
- struct KeepAlive {};
- #else
- struct KeepAlive {
-   std::vector<std::unique_ptr<exec_aten::TensorImpl>> tensors;
-   torch::util::KeepAliveSizes sizes;
- };
- #endif
-
- EValue pyToEValue(py::handle h, KeepAlive& keep_alive) {
-   const std::string& type_str = py::str(h.get_type());
-   EXECUTORCH_SCOPE_PROF("pyToEValue");
-   if (type_str == "<class 'torch.Tensor'>") {
-     auto atTensor = h.cast<ATTensor>();
- #ifdef USE_ATEN_LIB
-     EValue evalue(atTensor);
- #else
-     auto etensorImpl =
-         torch::util::eTensorFromAtTensor(atTensor, keep_alive.sizes);
-     EValue evalue(torch::executor::Tensor(etensorImpl.get()));
-     keep_alive.tensors.push_back(std::move(etensorImpl));
- #endif
-     return evalue;
-   } else if (py::isinstance<py::none>(h)) {
-     return EValue();
-   } else if (py::isinstance<py::bool_>(h)) {
-     return EValue(py::cast<bool>(h));
-   } else if (py::isinstance<py::int_>(h)) {
-     return EValue(py::cast<int64_t>(h));
-   } else {
-     // Unsupported pytype
-     ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
-   }
- }
-
- py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) {
-   EXECUTORCH_SCOPE_PROF("pyFromEValue");
-   if (Tag::None == v.tag) {
-     return py::none();
-   } else if (Tag::Int == v.tag) {
-     return py::cast(v.toInt());
-   } else if (Tag::Double == v.tag) {
-     return py::cast(v.toDouble());
-   } else if (Tag::Bool == v.tag) {
-     return py::cast(v.toBool());
-   } else if (Tag::Tensor == v.tag) {
- #ifdef USE_ATEN_LIB
-     return py::cast(v.toTensor().clone());
- #else
-     // Clone so the outputs in python do not share a lifetime with the module
-     // object
-     return py::cast(torch::util::atTensorFromETensor(
-                         v.toTensor().unsafeGetTensorImpl(), keep_alive.sizes)
-                         .clone());
- #endif
-   }
-   ET_ASSERT_UNREACHABLE();
- }
-
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;

struct PyBundledModule final {
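
Review note on the clone-on-output pattern that survives this refactor: returning a clone keeps Python-side outputs alive independently of the Module's internal buffers, which is why both the removed pyFromEValue and the new inline conversion below call .clone(). A minimal sketch, assuming only standard at::Tensor semantics (the helper name is made up for illustration):

    // Deep-copy a module-owned tensor before exposing it to Python, so the
    // returned object stays valid even after the Module is destroyed.
    at::Tensor detach_lifetime(const at::Tensor& module_owned) {
      return module_owned.clone();
    }
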
@@ -406,19 +338,113 @@ struct PyModule final {
  py::list run_method(
      const std::string& method_name,
      const py::sequence& inputs) {
-   std::vector<EValue> cpp_inputs;
    const auto inputs_size = py::len(inputs);
+   std::vector<EValue> cpp_inputs;
    cpp_inputs.reserve(inputs_size);
+
+ #ifndef USE_ATEN_LIB // Portable mode
+   // So the ETensors and their metadata stay in scope for Module->run_method.
+   std::vector<torch::executor::TensorImpl> input_tensors;
+   std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
+   std::vector<std::vector<torch::executor::Tensor::StridesType>>
+       input_strides;
+   std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
+       input_dim_order;
+   // We store pointers to elements of input_tensors, so it is important to
+   // reserve up front so those pointers are not invalidated by a vector
+   // resize (see the sketch after this hunk). The others don't need this:
+   // they are vectors of vectors, and we never store a pointer to the
+   // root-level vector data.
+   input_tensors.reserve(inputs_size);
+ #endif
+
+   // Convert Python objects into EValues.
    for (size_t i = 0; i < inputs_size; ++i) {
-     cpp_inputs.emplace_back(pyToEValue(inputs[i], keep_alive_));
+     auto python_input = inputs[i];
+     const std::string& type_str = py::str(python_input.get_type());
+     if (type_str == "<class 'torch.Tensor'>") {
+       auto at_tensor = python_input.cast<at::Tensor>();
+       // alias_etensor_to_attensor will assert on this later, so to better
+       // propagate up to Python we check early and throw an exception.
+       if (!at_tensor.is_contiguous()) {
+         auto error_msg = "Input " + std::to_string(i) + " for method " +
+             method_name + " is not contiguous.";
+         throw std::runtime_error(error_msg);
+       }
+
+ #ifdef USE_ATEN_LIB
+       EValue evalue(at_tensor);
+ #else
+       // Convert at::Tensor to torch::executor::Tensor.
+       auto type = torch::util::torchToExecuTorchScalarType(
+           at_tensor.options().dtype());
+       size_t dim = at_tensor.dim();
+       // Can't directly alias at::Tensor sizes and strides due to int64 vs
+       // int32 typing conflict.
+       input_sizes.emplace_back(
+           at_tensor.sizes().begin(), at_tensor.sizes().end());
+       input_strides.emplace_back(
+           at_tensor.strides().begin(), at_tensor.strides().end());
+
+       // Only works for MemoryFormat::Contiguous inputs.
+       std::vector<torch::executor::Tensor::DimOrderType> dim_order;
+       for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
+         dim_order.push_back(cur_dim);
+       }
+       input_dim_order.push_back(std::move(dim_order));
+       input_tensors.emplace_back(
+           type,
+           dim,
+           input_sizes[i].data(),
+           nullptr,
+           input_dim_order[i].data(),
+           input_strides[i].data());
+
+       torch::executor::Tensor temp =
+           torch::executor::Tensor(&input_tensors[i]);
+       torch::util::alias_etensor_to_attensor(at_tensor, temp);
+       EValue evalue(temp);
+ #endif
+
+       cpp_inputs.push_back(evalue);
+     } else if (py::isinstance<py::none>(python_input)) {
+       cpp_inputs.push_back(EValue());
+     } else if (py::isinstance<py::bool_>(python_input)) {
+       cpp_inputs.push_back(EValue(py::cast<bool>(python_input)));
+     } else if (py::isinstance<py::int_>(python_input)) {
+       cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
+     } else {
+       // Unsupported pytype.
+       ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
+     }
    }

    auto outputs = module_->run_method(method_name, cpp_inputs);

+   // Retrieve outputs.
    const auto outputs_size = outputs.size();
    py::list list(outputs_size);
    for (size_t i = 0; i < outputs_size; ++i) {
-     list[i] = pyFromEValue(outputs[i], keep_alive_);
+     auto& v = outputs[i];
+     if (Tag::None == v.tag) {
+       list[i] = py::none();
+     } else if (Tag::Int == v.tag) {
+       list[i] = py::cast(v.toInt());
+     } else if (Tag::Double == v.tag) {
+       list[i] = py::cast(v.toDouble());
+     } else if (Tag::Bool == v.tag) {
+       list[i] = py::cast(v.toBool());
+     } else if (Tag::Tensor == v.tag) {
+ #ifdef USE_ATEN_LIB
+       // Clone so the outputs in Python do not share a lifetime with the
+       // module object.
+       list[i] = py::cast(v.toTensor().clone());
+ #else
+       list[i] = py::cast(
+           torch::util::alias_attensor_to_etensor(v.toTensor()).clone());
+ #endif
+     } else {
+       ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+     }
    }
    return list;
  }
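
Review note: the input_tensors.reserve(inputs_size) call in this hunk is load-bearing, since &input_tensors[i] is stored inside each portable-mode EValue. A self-contained sketch of the reallocation hazard it prevents (plain std::vector, nothing ExecuTorch-specific):

    #include <vector>

    int main() {
      std::vector<int> elems;
      elems.reserve(2);      // capacity fixed up front
      elems.push_back(10);
      int* p = &elems[0];    // pointer into the vector, like &input_tensors[i]
      elems.push_back(20);   // no reallocation happens, so p stays valid
      // Without the reserve(), the second push_back could reallocate the
      // buffer and leave p dangling.
      return *p;             // returns 10
    }
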
@@ -428,7 +454,6 @@ struct PyModule final {
  }

 private:
- KeepAlive keep_alive_;
  std::unique_ptr<Module> module_;
};
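
One caller-facing consequence of the new contiguity check: a non-contiguous tensor now surfaces as a RuntimeError in Python (pybind11 translates std::runtime_error) instead of asserting deep inside alias_etensor_to_attensor. A hedged sketch of the caller-side remedy, using only the public at::Tensor API (helper name hypothetical):

    // Make an input safe for run_method's contiguity requirement.
    at::Tensor prepare_input(const at::Tensor& t) {
      return t.is_contiguous() ? t : t.contiguous();
    }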