Commit 02c1b3d

dulinriley authored and larryliu0820 committed
Add MethodMeta object for python visibility (#5571)
Summary:
Pull Request resolved: #5571

Some clients and consumers of the ExecuTorch program files (.pte) were requesting ways to access metadata like the sizes of tensors and the number of bytes they needed. When I told them how to access them in C++, they requested Python wrappers since they had processing scripts written in Python.

Add some implementations of MethodMeta and TensorInfo methods. Note that these become more expensive than in C++ because they need to allocate Python objects, but I doubt these are used in performance-sensitive applications anyway. And dealing with lifetimes of mixed C++/Python objects is complex, so I favored simple lifetimes.

Reviewed By: dbort

Differential Revision: D63288433

fbshipit-source-id: af775120a8ebd9bf455671a8ce1f158259aa50e6
1 parent eca44f0 commit 02c1b3d
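A minimal sketch of how a client might use the new bindings from Python. The import path is the usual pybindings entry point, but the model file is hypothetical; any .pte program with a 'forward' method would behave similarly:

    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    # Hypothetical .pte file; substitute a real exported program.
    module = _load_for_executorch("/tmp/model.pte")
    meta = module.method_meta("forward")
    print(meta.name(), meta.num_inputs(), meta.num_outputs())
    for i in range(meta.num_inputs()):
        info = meta.input_tensor_meta(i)
        # sizes() is a tuple of ints; nbytes() accounts for the dtype's byte width.
        print(info.sizes(), info.dtype(), info.is_memory_planned(), info.nbytes())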

File tree

3 files changed: +300 -1 lines changed

extension/pybindings/pybindings.cpp
extension/pybindings/pybindings.pyi
extension/pybindings/test/make_test.py

extension/pybindings/pybindings.cpp

Lines changed: 156 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
 #include <executorch/runtime/core/data_loader.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
 #include <executorch/runtime/executor/method.h>
 #include <executorch/runtime/executor/program.h>
 #include <executorch/runtime/kernel/operator_registry.h>
@@ -55,6 +56,16 @@
   } \
 })

+#define THROW_INDEX_IF_ERROR(error, message, ...)                 \
+  ({                                                              \
+    if ((error) != Error::Ok) {                                   \
+      char msg_buf[128];                                          \
+      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
+      /* pybind will convert this to a python exception. */       \
+      throw std::out_of_range(msg_buf);                           \
+    }                                                             \
+  })
+
 // Our logs work by writing to stderr. By default this is done through fprintf
 // (as defined in posix.cpp) which then does not show up in python environments.
 // Here we override the pal to use std::cerr which can be properly redirected by
@@ -448,6 +459,119 @@ struct PyBundledModule final {
   size_t program_len_;
 };

+/// Expose a subset of TensorInfo information to python.
+struct PyTensorInfo final {
+  explicit PyTensorInfo(
+      std::shared_ptr<Module> module,
+      torch::executor::TensorInfo info)
+      : module_(std::move(module)), info_(info) {}
+
+  py::tuple sizes() const {
+    const auto shape = info_.sizes();
+    py::tuple tup(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i) {
+      tup[i] = py::cast(shape[i]);
+    }
+    return tup;
+  }
+
+  int8_t dtype() const {
+    return static_cast<std::underlying_type<exec_aten::ScalarType>::type>(
+        info_.scalar_type());
+  }
+
+  bool is_memory_planned() const {
+    return info_.is_memory_planned();
+  }
+
+  size_t nbytes() const {
+    return info_.nbytes();
+  }
+
+  std::string repr() const {
+    std::string size_str = "[";
+    for (const auto& d : info_.sizes()) {
+      size_str.append(std::to_string(d));
+      size_str.append(", ");
+    }
+    if (size_str.length() >= 2) {
+      // Pop the last two characters (comma and space) and add close bracket.
+      size_str.pop_back();
+      size_str.pop_back();
+    }
+    size_str.append("]");
+    return "TensorInfo(sizes=" + size_str + ", dtype=" +
+        std::string(executorch::runtime::toString(info_.scalar_type())) +
+        ", is_memory_planned=" +
+        (info_.is_memory_planned() ? "True" : "False") +
+        ", nbytes=" + std::to_string(info_.nbytes()) + ")";
+  }
+
+ private:
+  // TensorInfo relies on module to be alive.
+  std::shared_ptr<Module> module_;
+  torch::executor::TensorInfo info_;
+};
+
+/// Expose a subset of MethodMeta information to python.
+struct PyMethodMeta final {
+  explicit PyMethodMeta(
+      std::shared_ptr<Module> module,
+      torch::executor::MethodMeta meta)
+      : module_(std::move(module)), meta_(meta) {}
+
+  const char* name() const {
+    return meta_.name();
+  }
+
+  size_t num_inputs() const {
+    return meta_.num_inputs();
+  }
+
+  std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
+    const auto result = meta_.input_tensor_meta(index);
+    THROW_INDEX_IF_ERROR(
+        result.error(), "Cannot get input tensor meta at %zu", index);
+    return std::make_unique<PyTensorInfo>(module_, result.get());
+  }
+
+  size_t num_outputs() const {
+    return meta_.num_outputs();
+  }
+
+  std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
+    const auto result = meta_.output_tensor_meta(index);
+    THROW_INDEX_IF_ERROR(
+        result.error(), "Cannot get output tensor meta at %zu", index);
+    return std::make_unique<PyTensorInfo>(module_, result.get());
+  }
+
+  py::str repr() const {
+    py::list input_meta_strs;
+    for (size_t i = 0; i < meta_.num_inputs(); ++i) {
+      input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
+    }
+    py::list output_meta_strs;
+    for (size_t i = 0; i < meta_.num_outputs(); ++i) {
+      output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
+    }
+    // Add quotes to be more similar to Python's repr for strings.
+    py::str format =
+        "MethodMeta(name='{}', num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
+    return format.format(
+        std::string(meta_.name()),
+        std::to_string(meta_.num_inputs()),
+        input_meta_strs,
+        std::to_string(meta_.num_outputs()),
+        output_meta_strs);
+  }
+
+ private:
+  // Must keep the Module object alive or else the meta object is invalidated.
+  std::shared_ptr<Module> module_;
+  torch::executor::MethodMeta meta_;
+};
+
 struct PyModule final {
   explicit PyModule(
       const py::bytes& buffer,
@@ -751,8 +875,13 @@ struct PyModule final {
     return list;
   }

+  std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
+    auto& method = module_->get_method(method_name);
+    return std::make_unique<PyMethodMeta>(module_, method.method_meta());
+  }
+
  private:
-  std::unique_ptr<Module> module_;
+  std::shared_ptr<Module> module_;
   // Need to keep-alive output storages until they can be compared in case of
   // bundled programs.
   std::vector<std::vector<uint8_t>> output_storages_;
@@ -866,6 +995,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("method_name"),
           py::arg("clone_outputs") = true,
           call_guard)
+      .def(
+          "method_meta",
+          &PyModule::method_meta,
+          py::arg("method_name"),
+          call_guard)
       .def(
           "run_method",
          &PyModule::run_method,
@@ -900,6 +1034,27 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);

   py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyTensorInfo>(m, "TensorInfo")
+      .def("sizes", &PyTensorInfo::sizes, call_guard)
+      .def("dtype", &PyTensorInfo::dtype, call_guard)
+      .def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
+      .def("nbytes", &PyTensorInfo::nbytes, call_guard)
+      .def("__repr__", &PyTensorInfo::repr, call_guard);
+  py::class_<PyMethodMeta>(m, "MethodMeta")
+      .def("name", &PyMethodMeta::name, call_guard)
+      .def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
+      .def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
+      .def(
+          "input_tensor_meta",
+          &PyMethodMeta::input_tensor_meta,
+          py::arg("index"),
+          call_guard)
+      .def(
+          "output_tensor_meta",
+          &PyMethodMeta::output_tensor_meta,
+          py::arg("index"),
+          call_guard)
+      .def("__repr__", &PyMethodMeta::repr, call_guard);
 }

 } // namespace pybindings
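Two behaviors in the binding code above are worth calling out: pybind11 translates the std::out_of_range thrown by THROW_INDEX_IF_ERROR into a Python IndexError, and the shared_ptr<Module> captured by each wrapper keeps the module alive independently of the ExecuTorchModule object. A sketch of both (the model path is hypothetical):

    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    module = _load_for_executorch("/tmp/model.pte")  # hypothetical path
    meta = module.method_meta("forward")
    del module  # safe: meta holds its own shared_ptr to the underlying Module
    try:
        meta.input_tensor_meta(meta.num_inputs())  # one past the last input
    except IndexError as e:
        print(e)  # e.g. "Cannot get input tensor meta at 2"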

extension/pybindings/pybindings.pyi

Lines changed: 69 additions & 0 deletions
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
+from __future__ import annotations
+
 from typing import Any, Dict, List, Optional, Sequence, Tuple

 from executorch.exir._warnings import experimental

@@ -43,6 +45,7 @@ class ExecuTorchModule:
     def write_etdump_result_to_file(
         self, path: str, debug_buffer_path: Optional[str] = None
     ) -> None: ...
+    def method_meta(self, method_name: str) -> MethodMeta: ...

 @experimental("This API is experimental and subject to change without notice.")
 class BundledModule:
5457

5558
...
5659

60+
@experimental("This API is experimental and subject to change without notice.")
61+
class TensorInfo:
62+
"""Metadata about a tensor such as the shape and dtype.
63+
64+
.. warning::
65+
66+
This API is experimental and subject to change without notice.
67+
"""
68+
69+
def sizes(self) -> Tuple[int, ...]:
70+
"""Shape of the tensor as a tuple"""
71+
...
72+
73+
def dtype(self) -> int:
74+
"""The data type of the elements inside the tensor.
75+
See documentation for ScalarType in executorch/runtime/core/portable_type/scalar_type.h
76+
for the values these integers can take."""
77+
...
78+
79+
def is_memory_planned(self) -> bool:
80+
"""True if the tensor is already memory planned, meaning no allocation
81+
needs to be provided. False otherwise"""
82+
...
83+
84+
def nbytes(self) -> int:
85+
"""Number of bytes in the tensor. Not the same as numel if the dtype is
86+
larger than 1 byte wide"""
87+
...
88+
89+
def __repr__(self) -> str: ...
90+
91+
@experimental("This API is experimental and subject to change without notice.")
92+
class MethodMeta:
93+
"""Metadata about a method such as the number of inputs and outputs.
94+
95+
.. warning::
96+
97+
This API is experimental and subject to change without notice.
98+
"""
99+
100+
def name(self) -> str:
101+
"""The name of the method, such as 'forward'"""
102+
...
103+
104+
def num_inputs(self) -> int:
105+
"""The number of user inputs to the method. This does not include any
106+
internal buffers or weights, which don't need to be provided by the user"""
107+
...
108+
109+
def num_outputs(self) -> int:
110+
"""The number of outputs from the method. This does not include any mutated
111+
internal buffers"""
112+
...
113+
114+
def input_tensor_meta(self, index: int) -> TensorInfo:
115+
"""The tensor info for the 'index'th input. Index must be in the interval
116+
[0, num_inputs()). Raises an IndexError if the index is out of bounds"""
117+
...
118+
119+
def output_tensor_meta(self, index: int) -> TensorInfo:
120+
"""The tensor info for the 'index'th output. Index must be in the interval
121+
[0, num_outputs()). Raises an IndexError if the index is out of bounds"""
122+
...
123+
124+
def __repr__(self) -> str: ...
125+
57126
@experimental("This API is experimental and subject to change without notice.")
58127
def _load_for_executorch(
59128
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
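Since dtype() returns the raw ScalarType integer, callers may want a small lookup table for readable output. The numbering below mirrors torch's ScalarType and is an assumption for illustration; consult scalar_type.h for the authoritative list:

    # Assumed ScalarType numbering (mirrors torch's); verify against scalar_type.h.
    SCALAR_TYPE_NAMES = {
        0: "Byte", 1: "Char", 2: "Short", 3: "Int", 4: "Long",
        5: "Half", 6: "Float", 7: "Double", 11: "Bool",
    }

    def dtype_name(info) -> str:
        # info is a TensorInfo from input_tensor_meta()/output_tensor_meta().
        return SCALAR_TYPE_NAMES.get(info.dtype(), f"Unknown({info.dtype()})")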

extension/pybindings/test/make_test.py

Lines changed: 75 additions & 0 deletions
@@ -251,12 +251,87 @@ def test_quantized_ops(tester):
         expected = example_inputs[0] + example_inputs[1]
         tester.assertEqual(str(expected), str(executorch_output))

+    def test_constant_output_not_memory_planned(tester):
+        # Create an ExecuTorch program from ModuleAdd.
+        exported_program, inputs = create_program(
+            ModuleAddConstReturn(),
+            et_config=ExecutorchBackendConfig(
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
+            ),
+        )
+
+        exported_program.dump_executorch_program(verbose=True)
+
+        # Use pybindings to load and execute the program.
+        executorch_module = load_fn(exported_program.buffer)
+        # Invoke the callable on executorch_module instead of calling module.forward.
+        # Use only one input to test this case.
+        executorch_output = executorch_module((torch.ones(2, 2),))
+        print(executorch_output)
+
+        # The test module adds the input to torch.ones(2,2), so its output should be
+        # the same as adding them directly.
+        expected = torch.ones(2, 2) + torch.ones(2, 2)
+        tester.assertEqual(str(expected), str(executorch_output[0]))
+
+        # The test module returns the state. Check that its value is correct.
+        tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
+
+    def test_method_meta(tester) -> None:
+        # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
+        exported_program, inputs = create_program(ModuleAdd())
+
+        # Use pybindings to load the program and query its metadata.
+        executorch_module = load_fn(exported_program.buffer)
+        meta = executorch_module.method_meta("forward")
+
+        # Ensure that all these APIs work even if the module object is destroyed.
+        del executorch_module
+        tester.assertEqual(meta.name(), "forward")
+        tester.assertEqual(meta.num_inputs(), 2)
+        tester.assertEqual(meta.num_outputs(), 1)
+        # Common string for all these tensors.
+        tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
+        float_dtype = 6
+        tester.assertEqual(
+            str(meta),
+            "MethodMeta(name='forward', num_inputs=2, "
+            f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
+            f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
+        )
+
+        input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
+        output_tensor = meta.output_tensor_meta(0)
+        # Check that accessing out of bounds raises an IndexError.
+        with tester.assertRaises(IndexError):
+            meta.input_tensor_meta(2)
+        # Test that tensor metadata can outlive method metadata.
+        del meta
+        tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
+        tester.assertEqual(
+            [t.dtype() for t in input_tensors], [float_dtype, float_dtype]
+        )
+        tester.assertEqual(
+            [t.is_memory_planned() for t in input_tensors], [True, True]
+        )
+        tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
+        tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")
+
+        tester.assertEqual(output_tensor.sizes(), (2, 2))
+        tester.assertEqual(output_tensor.dtype(), float_dtype)
+        tester.assertEqual(output_tensor.is_memory_planned(), True)
+        tester.assertEqual(output_tensor.nbytes(), 16)
+        tester.assertEqual(str(output_tensor), tensor_info)
+
+    ######### RUN TEST CASES #########
     test_e2e(tester)
     test_multiple_entry(tester)
     test_output_lifespan(tester)
     test_module_callable(tester)
     test_module_single_input(tester)
     test_stderr_redirect(tester)
     test_quantized_ops(tester)
+    test_constant_output_not_memory_planned(tester)
+    test_method_meta(tester)

     return wrapper
