Skip to content

Allow getting all backend names #8520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_create_profile_block, # noqa: F401
_dump_profile_results, # noqa: F401
_get_operator_names, # noqa: F401
_get_registered_backend_names, # noqa: F401
_load_bundled_program_from_buffer, # noqa: F401
_load_for_executorch, # noqa: F401
_load_for_executorch_from_buffer, # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/mmap_data_loader.h>
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
Expand Down Expand Up @@ -91,6 +92,8 @@ using ::executorch::runtime::DataLoader;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
using ::executorch::runtime::EventTracerDebugLogLevel;
using ::executorch::runtime::get_backend_name;
using ::executorch::runtime::get_num_registered_backends;
using ::executorch::runtime::get_registered_kernels;
using ::executorch::runtime::HierarchicalAllocator;
using ::executorch::runtime::Kernel;
Expand Down Expand Up @@ -975,6 +978,18 @@ py::list get_operator_names() {
return res;
}

py::list get_registered_backend_names() {
size_t n_of_registered_backends = get_num_registered_backends();
py::list res;
for (size_t i = 0; i < n_of_registered_backends; i++) {
auto backend_name_res = get_backend_name(i);
THROW_IF_ERROR(backend_name_res.error(), "Failed to get backend name");
auto backend_name = backend_name_res.get();
res.append(backend_name);
}
return res;
}

} // namespace

PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
Expand Down Expand Up @@ -1028,6 +1043,10 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
prof_result.num_bytes);
},
call_guard);
m.def(
"_get_registered_backend_names",
&get_registered_backend_names,
call_guard);
m.def("_get_operator_names", &get_operator_names);
m.def("_create_profile_block", &create_profile_block, call_guard);
m.def(
Expand Down
9 changes: 9 additions & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def _get_operator_names() -> List[str]:
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _get_registered_backend_names() -> List[str]:
"""
.. warning::

This API is experimental and subject to change without notice.
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _create_profile_block(name: str) -> None:
"""
Expand Down
8 changes: 8 additions & 0 deletions extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ runtime.python_test(
"//executorch/kernels/quantized:aot_lib",
],
)

runtime.python_test(
name = "test_backend_pybinding",
srcs = ["test_backend_pybinding.py"],
deps = [
"//executorch/runtime:runtime",
],
)
14 changes: 14 additions & 0 deletions extension/pybindings/test/test_backend_pybinding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest

from executorch.runtime import Runtime


class TestBackendsPybinding(unittest.TestCase):
def test_backend_name_list(
self,
) -> None:

runtime = Runtime.get()
registered_backend_names = runtime.backend_registry.registered_backend_names
self.assertGreaterEqual(len(registered_backend_names), 1)
self.assertIn("XnnpackBackend", registered_backend_names)
18 changes: 17 additions & 1 deletion runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import functools
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, Optional, Sequence, Set, Union
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Set, Union

try:
from executorch.extension.pybindings.portable_lib import (
Expand Down Expand Up @@ -125,6 +125,21 @@ def load_method(self, name: str) -> Optional[Method]:
return self._methods.get(name, None)


class BackendRegistry:
"""The registry of backends that are available to the runtime."""

def __init__(self, legacy_module: ModuleType) -> None:
# TODO: Expose the kernel callables to Python.
self._legacy_module = legacy_module

@property
def registered_backend_names(self) -> List[str]:
"""
Returns the names of all registered backends as a list of strings.
"""
return self._legacy_module._get_registered_backend_names()


class OperatorRegistry:
"""The registry of operators that are available to the runtime."""

Expand Down Expand Up @@ -157,6 +172,7 @@ def get() -> "Runtime":

def __init__(self, *, legacy_module: ModuleType) -> None:
# Public attributes.
self.backend_registry = BackendRegistry(legacy_module)
self.operator_registry = OperatorRegistry(legacy_module)
# Private attributes.
self._legacy_module = legacy_module
Expand Down
11 changes: 11 additions & 0 deletions runtime/backend/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,16 @@ Error register_backend(const Backend& backend) {
return Error::Ok;
}

size_t get_num_registered_backends() {
return num_registered_backends;
}

Result<const char*> get_backend_name(size_t index) {
if (index >= num_registered_backends) {
return Error::InvalidArgument;
}
return registered_backends[index].name;
}

} // namespace runtime
} // namespace executorch
10 changes: 10 additions & 0 deletions runtime/backend/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ struct Backend {
*/
ET_NODISCARD Error register_backend(const Backend& backend);

/**
* Returns the number of registered backends.
*/
size_t get_num_registered_backends();

/**
* Returns the backend name at the given index.
*/
Result<const char*> get_backend_name(size_t index);

} // namespace runtime
} // namespace executorch

Expand Down
Loading