Skip to content

Commit 2a448da

Browse files
[mlir][python] Make types in register_(dialect|operation) more narrow. (llvm#115307)
This PR makes the `pyClass`/`dialectClass` arguments of the pybind11 functions `register_dialect` and `register_operation` as well as their return types more narrow, concretely, a `py::type` instead of a `py::object`. As the name of the arguments indicate, they have to be called with a type instance (a "class"). The PR also updates the typing stubs of these functions (in the corresponding `.pyi` file), such that static type checkers are aware of the changed type. With the previous typing information, `pyright` raised errors on code generated by tablegen. Signed-off-by: Ingo Müller <[email protected]>
1 parent d1aa0da commit 2a448da

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlir, m) {
5858
// Registration decorators.
5959
m.def(
6060
"register_dialect",
61-
[](py::object pyClass) {
61+
[](py::type pyClass) {
6262
std::string dialectNamespace =
6363
pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
6464
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
@@ -68,9 +68,9 @@ PYBIND11_MODULE(_mlir, m) {
6868
"Class decorator for registering a custom Dialect wrapper");
6969
m.def(
7070
"register_operation",
71-
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
71+
[](const py::type &dialectClass, bool replace) -> py::cpp_function {
7272
return py::cpp_function(
73-
[dialectClass, replace](py::object opClass) -> py::object {
73+
[dialectClass, replace](py::type opClass) -> py::type {
7474
std::string operationName =
7575
opClass.attr("OPERATION_NAME").cast<std::string>();
7676
PyGlobals::get().registerOperationImpl(operationName, opClass,

mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ class _Globals:
88
def append_dialect_search_prefix(self, module_name: str) -> None: ...
99
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
1010

11-
def register_dialect(dialect_class: type) -> object: ...
12-
def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
11+
def register_dialect(dialect_class: type) -> type: ...
12+
def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ...

0 commit comments

Comments
 (0)