24 | 24 | #include <executorch/extension/data_loader/mmap_data_loader.h>
25 | 25 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26 | 26 | #include <executorch/runtime/core/data_loader.h>
| 27 | +#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
27 | 28 | #include <executorch/runtime/executor/method.h>
28 | 29 | #include <executorch/runtime/executor/program.h>
29 | 30 | #include <executorch/runtime/kernel/operator_registry.h>

55 | 56 |   }                                                           \
56 | 57 |   })
57 | 58 |
| 59 | +#define THROW_INDEX_IF_ERROR(error, message, ...)                \
| 60 | +  ({                                                            \
| 61 | +    if ((error) != Error::Ok) {                                 \
| 62 | +      char msg_buf[128];                                        \
| 63 | +      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
| 64 | +      /* pybind will convert this to a python exception. */     \
| 65 | +      throw std::out_of_range(msg_buf);                         \
| 66 | +    }                                                           \
| 67 | +  })
| 68 | +
58 | 69 | // Our logs work by writing to stderr. By default this is done through fprintf
59 | 70 | // (as defined in posix.cpp) which then does not show up in python environments.
60 | 71 | // Here we override the pal to use std::cerr which can be properly redirected by
@@ -448,6 +459,119 @@ struct PyBundledModule final {
448 | 459 |   size_t program_len_;
449 | 460 | };
450 | 461 |
| 462 | +/// Expose a subset of TensorInfo information to python.
| 463 | +struct PyTensorInfo final {
| 464 | +  explicit PyTensorInfo(
| 465 | +      std::shared_ptr<Module> module,
| 466 | +      torch::executor::TensorInfo info)
| 467 | +      : module_(std::move(module)), info_(info) {}
| 468 | +
| 469 | +  py::tuple sizes() const {
| 470 | +    const auto shape = info_.sizes();
| 471 | +    py::tuple tup(shape.size());
| 472 | +    for (size_t i = 0; i < shape.size(); ++i) {
| 473 | +      tup[i] = py::cast(shape[i]);
| 474 | +    }
| 475 | +    return tup;
| 476 | +  }
| 477 | +
| 478 | +  int8_t dtype() const {
| 479 | +    return static_cast<std::underlying_type<exec_aten::ScalarType>::type>(
| 480 | +        info_.scalar_type());
| 481 | +  }
| 482 | +
| 483 | +  bool is_memory_planned() const {
| 484 | +    return info_.is_memory_planned();
| 485 | +  }
| 486 | +
| 487 | +  size_t nbytes() const {
| 488 | +    return info_.nbytes();
| 489 | +  }
| 490 | +
| 491 | +  std::string repr() const {
| 492 | +    std::string size_str = "[";
| 493 | +    for (const auto& d : info_.sizes()) {
| 494 | +      size_str.append(std::to_string(d));
| 495 | +      size_str.append(", ");
| 496 | +    }
| 497 | +    if (size_str.length() >= 2) {
| 498 | +      // Pop the last two characters (comma and space) and add close bracket.
| 499 | +      size_str.pop_back();
| 500 | +      size_str.pop_back();
| 501 | +    }
| 502 | +    size_str.append("]");
| 503 | +    return "TensorInfo(sizes=" + size_str + ", dtype=" +
| 504 | +        std::string(executorch::runtime::toString(info_.scalar_type())) +
| 505 | +        ", is_memory_planned=" +
| 506 | +        (info_.is_memory_planned() ? "True" : "False") +
| 507 | +        ", nbytes=" + std::to_string(info_.nbytes()) + ")";
| 508 | +  }
| 509 | +
| 510 | + private:
| 511 | +  // TensorInfo relies on the module remaining alive.
| 512 | +  std::shared_ptr<Module> module_;
| 513 | +  torch::executor::TensorInfo info_;
| 514 | +};
| 515 | +
| 516 | +/// Expose a subset of MethodMeta information to python.
| 517 | +struct PyMethodMeta final {
| 518 | +  explicit PyMethodMeta(
| 519 | +      std::shared_ptr<Module> module,
| 520 | +      torch::executor::MethodMeta meta)
| 521 | +      : module_(std::move(module)), meta_(meta) {}
| 522 | +
| 523 | +  const char* name() const {
| 524 | +    return meta_.name();
| 525 | +  }
| 526 | +
| 527 | +  size_t num_inputs() const {
| 528 | +    return meta_.num_inputs();
| 529 | +  }
| 530 | +
| 531 | +  std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
| 532 | +    const auto result = meta_.input_tensor_meta(index);
| 533 | +    THROW_INDEX_IF_ERROR(
| 534 | +        result.error(), "Cannot get input tensor meta at %zu", index);
| 535 | +    return std::make_unique<PyTensorInfo>(module_, result.get());
| 536 | +  }
| 537 | +
| 538 | +  size_t num_outputs() const {
| 539 | +    return meta_.num_outputs();
| 540 | +  }
| 541 | +
| 542 | +  std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
| 543 | +    const auto result = meta_.output_tensor_meta(index);
| 544 | +    THROW_INDEX_IF_ERROR(
| 545 | +        result.error(), "Cannot get output tensor meta at %zu", index);
| 546 | +    return std::make_unique<PyTensorInfo>(module_, result.get());
| 547 | +  }
| 548 | +
| 549 | +  py::str repr() const {
| 550 | +    py::list input_meta_strs;
| 551 | +    for (size_t i = 0; i < meta_.num_inputs(); ++i) {
| 552 | +      input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
| 553 | +    }
| 554 | +    py::list output_meta_strs;
| 555 | +    for (size_t i = 0; i < meta_.num_outputs(); ++i) {
| 556 | +      output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
| 557 | +    }
| 558 | +    // Add quotes to be more similar to Python's repr for strings.
| 559 | +    py::str format =
| 560 | +        "MethodMeta(name='{}', num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
| 561 | +    return format.format(
| 562 | +        std::string(meta_.name()),
| 563 | +        std::to_string(meta_.num_inputs()),
| 564 | +        input_meta_strs,
| 565 | +        std::to_string(meta_.num_outputs()),
| 566 | +        output_meta_strs);
| 567 | +  }
| 568 | +
| 569 | + private:
| 570 | +  // Must keep the Module object alive or else the meta object is invalidated.
| 571 | +  std::shared_ptr<Module> module_;
| 572 | +  torch::executor::MethodMeta meta_;
| 573 | +};
| 574 | +
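Continuing the sketch from above, the repr() implementations yield strings like the following for a hypothetical method with one float input of shape [1, 2] (all values are assumptions, not taken from a real model):

    # Illustrative output only; shape, dtype, and planning flag are assumed.
    info = meta.input_tensor_meta(0)
    print(info)
    # TensorInfo(sizes=[1, 2], dtype=Float, is_memory_planned=True, nbytes=8)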
451 | 575 | struct PyModule final {
452 | 576 |   explicit PyModule(
453 | 577 |       const py::bytes& buffer,

@@ -751,8 +875,13 @@ struct PyModule final {

751 | 875 |     return list;
752 | 876 |   }
753 | 877 |
| 878 | +  std::unique_ptr<PyMethodMeta> method_meta(const std::string& method_name) {
| 879 | +    auto& method = module_->get_method(method_name);
| 880 | +    return std::make_unique<PyMethodMeta>(module_, method.method_meta());
| 881 | +  }
| 882 | +
754 | 883 |  private:
755 | | -  std::unique_ptr<Module> module_;
| 884 | +  std::shared_ptr<Module> module_;
756 | 885 |   // Need to keep output storages alive until they can be compared in case of
757 | 886 |   // bundled programs.
758 | 887 |   std::vector<std::vector<uint8_t>> output_storages_;
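The unique_ptr-to-shared_ptr switch above is what lets PyMethodMeta and PyTensorInfo co-own the Module. A sketch of the lifetime guarantee this buys, assuming the bindings registered below:

    # Sketch: metadata stays valid even if the Python module reference is
    # dropped first, because PyMethodMeta holds a shared_ptr<Module> internally.
    meta = module.method_meta("forward")
    del module
    meta.num_inputs()  # still safe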
@@ -866,6 +995,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
866 | 995 |           py::arg("method_name"),
867 | 996 |           py::arg("clone_outputs") = true,
868 | 997 |           call_guard)
| 998 | +      .def(
| 999 | +          "method_meta",
| 1000 | +          &PyModule::method_meta,
| 1001 | +          py::arg("method_name"),
| 1002 | +          call_guard)
869 | 1003 |       .def(
870 | 1004 |           "run_method",
871 | 1005 |           &PyModule::run_method,
@@ -900,6 +1034,27 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
900 | 1034 |       call_guard);
901 | 1035 |
902 | 1036 |   py::class_<PyBundledModule>(m, "BundledModule");
| 1037 | +  py::class_<PyTensorInfo>(m, "TensorInfo")
| 1038 | +      .def("sizes", &PyTensorInfo::sizes, call_guard)
| 1039 | +      .def("dtype", &PyTensorInfo::dtype, call_guard)
| 1040 | +      .def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
| 1041 | +      .def("nbytes", &PyTensorInfo::nbytes, call_guard)
| 1042 | +      .def("__repr__", &PyTensorInfo::repr, call_guard);
| 1043 | +  py::class_<PyMethodMeta>(m, "MethodMeta")
| 1044 | +      .def("name", &PyMethodMeta::name, call_guard)
| 1045 | +      .def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
| 1046 | +      .def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
| 1047 | +      .def(
| 1048 | +          "input_tensor_meta",
| 1049 | +          &PyMethodMeta::input_tensor_meta,
| 1050 | +          py::arg("index"),
| 1051 | +          call_guard)
| 1052 | +      .def(
| 1053 | +          "output_tensor_meta",
| 1054 | +          &PyMethodMeta::output_tensor_meta,
| 1055 | +          py::arg("index"),
| 1056 | +          call_guard)
| 1057 | +      .def("__repr__", &PyMethodMeta::repr, call_guard);
903 | 1058 | }
904 | 1059 |
905 | 1060 | } // namespace pybindings
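Putting the new bindings together, a hedged end-to-end usage sketch; the import path and model file are assumptions (the extension name ultimately depends on EXECUTORCH_PYTHON_MODULE_NAME at build time):

    # Assumed import location; adjust to however the extension is built.
    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    module = _load_for_executorch("/path/to/model.pte")  # hypothetical path
    meta = module.method_meta("forward")
    print(meta.name(), meta.num_inputs(), meta.num_outputs())
    for i in range(meta.num_inputs()):
        info = meta.input_tensor_meta(i)
        print(info.sizes(), info.dtype(), info.nbytes(), info.is_memory_planned())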