Skip to content

Commit 26f8fa8

Browse files
cccclaifacebook-github-bot
authored andcommitted
Allow getting all backend names (#8520)
Summary: Allow getting all backends name in both python and c++ Reviewed By: omerjerk Differential Revision: D69691354
1 parent 80d5e5a commit 26f8fa8

File tree

8 files changed

+89
-1
lines changed

8 files changed

+89
-1
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_create_profile_block, # noqa: F401
3939
_dump_profile_results, # noqa: F401
4040
_get_operator_names, # noqa: F401
41+
_get_registered_backend_names, # noqa: F401
4142
_load_bundled_program_from_buffer, # noqa: F401
4243
_load_for_executorch, # noqa: F401
4344
_load_for_executorch_from_buffer, # noqa: F401

extension/pybindings/pybindings.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <executorch/extension/data_loader/buffer_data_loader.h>
2424
#include <executorch/extension/data_loader/mmap_data_loader.h>
2525
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+
#include <executorch/runtime/backend/interface.h>
2627
#include <executorch/runtime/core/data_loader.h>
2728
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
2829
#include <executorch/runtime/executor/method.h>
@@ -91,6 +92,8 @@ using ::executorch::runtime::DataLoader;
9192
using ::executorch::runtime::Error;
9293
using ::executorch::runtime::EValue;
9394
using ::executorch::runtime::EventTracerDebugLogLevel;
95+
using ::executorch::runtime::get_backend_name;
96+
using ::executorch::runtime::get_num_registered_backends;
9497
using ::executorch::runtime::get_registered_kernels;
9598
using ::executorch::runtime::HierarchicalAllocator;
9699
using ::executorch::runtime::Kernel;
@@ -975,6 +978,18 @@ py::list get_operator_names() {
975978
return res;
976979
}
977980

981+
py::list get_registered_backend_names() {
982+
size_t n_of_registered_backends = get_num_registered_backends();
983+
py::list res;
984+
for (size_t i = 0; i < n_of_registered_backends; i++) {
985+
auto backend_name_res = get_backend_name(i);
986+
THROW_IF_ERROR(backend_name_res.error(), "Failed to get backend name");
987+
auto backend_name = backend_name_res.get();
988+
res.append(backend_name);
989+
}
990+
return res;
991+
}
992+
978993
} // namespace
979994

980995
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1028,6 +1043,10 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10281043
prof_result.num_bytes);
10291044
},
10301045
call_guard);
1046+
m.def(
1047+
"_get_registered_backend_names",
1048+
&get_registered_backend_names,
1049+
call_guard);
10311050
m.def("_get_operator_names", &get_operator_names);
10321051
m.def("_create_profile_block", &create_profile_block, call_guard);
10331052
m.def(

extension/pybindings/pybindings.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def _get_operator_names() -> List[str]:
220220
"""
221221
...
222222

223+
@experimental("This API is experimental and subject to change without notice.")
224+
def _get_registered_backend_names() -> List[str]:
225+
"""
226+
.. warning::
227+
228+
This API is experimental and subject to change without notice.
229+
"""
230+
...
231+
223232
@experimental("This API is experimental and subject to change without notice.")
224233
def _create_profile_block(name: str) -> None:
225234
"""

extension/pybindings/test/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,11 @@ runtime.python_test(
4747
"//executorch/kernels/quantized:aot_lib",
4848
],
4949
)
50+
51+
runtime.python_test(
52+
name = "test_backend_pybinding",
53+
srcs = ["test_backend_pybinding.py"],
54+
deps = [
55+
"//executorch/runtime:runtime",
56+
],
57+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import unittest
2+
3+
from executorch.runtime import Runtime
4+
5+
6+
class TestBackendsPybinding(unittest.TestCase):
7+
def test_backend_name_list(
8+
self,
9+
) -> None:
10+
11+
runtime = Runtime.get()
12+
registered_backend_names = runtime.backend_registry.registered_backend_names
13+
self.assertGreaterEqual(len(registered_backend_names), 1)
14+
self.assertIn("XnnpackBackend", registered_backend_names)

runtime/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import functools
4343
from pathlib import Path
4444
from types import ModuleType
45-
from typing import Any, BinaryIO, Dict, Optional, Sequence, Set, Union
45+
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Set, Union
4646

4747
try:
4848
from executorch.extension.pybindings.portable_lib import (
@@ -125,6 +125,21 @@ def load_method(self, name: str) -> Optional[Method]:
125125
return self._methods.get(name, None)
126126

127127

128+
class BackendRegistry:
129+
"""The registry of backends that are available to the runtime."""
130+
131+
def __init__(self, legacy_module: ModuleType) -> None:
132+
# TODO: Expose the kernel callables to Python.
133+
self._legacy_module = legacy_module
134+
135+
@property
136+
def registered_backend_names(self) -> List[str]:
137+
"""
138+
Returns the names of all registered backends as a list of strings.
139+
"""
140+
return self._legacy_module._get_registered_backend_names()
141+
142+
128143
class OperatorRegistry:
129144
"""The registry of operators that are available to the runtime."""
130145

@@ -157,6 +172,7 @@ def get() -> "Runtime":
157172

158173
def __init__(self, *, legacy_module: ModuleType) -> None:
159174
# Public attributes.
175+
self.backend_registry = BackendRegistry(legacy_module)
160176
self.operator_registry = OperatorRegistry(legacy_module)
161177
# Private attributes.
162178
self._legacy_module = legacy_module

runtime/backend/interface.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,16 @@ Error register_backend(const Backend& backend) {
5555
return Error::Ok;
5656
}
5757

58+
size_t get_num_registered_backends() {
59+
return num_registered_backends;
60+
}
61+
62+
Result<const char*> get_backend_name(size_t index) {
63+
if (index >= num_registered_backends) {
64+
return Error::InvalidArgument;
65+
}
66+
return registered_backends[index].name;
67+
}
68+
5869
} // namespace runtime
5970
} // namespace executorch

runtime/backend/interface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ struct Backend {
139139
*/
140140
ET_NODISCARD Error register_backend(const Backend& backend);
141141

142+
/**
143+
* Returns the number of registered backends.
144+
*/
145+
size_t get_num_registered_backends();
146+
147+
/**
148+
* Returns the backend name at the given index.
149+
*/
150+
Result<const char*> get_backend_name(size_t index);
151+
142152
} // namespace runtime
143153
} // namespace executorch
144154

0 commit comments

Comments
 (0)