Skip to content

Commit adbc118

Browse files
committed
[mlir python] Add nanobind support for standalone dialects.
This PR allows out-of-tree dialects to write Python dialect modules using nanobind instead of pybind11. It may make sense to migrate in-tree dialects and some of the ODS Python infrastructure to nanobind, but that is a topic for a future change. This PR makes the following changes: * adds nanobind to the CMake and Bazel build systems. We also add robin_map to the Bazel build, which is a dependency of nanobind. * adds a PYTHON_BINDING_LIBRARY option to various CMake functions, such as declare_mlir_python_extension, allowing users to select a Python binding library. * creates a fork of mlir/include/mlir/Bindings/Python/PybindAdaptors.h named NanobindAdaptors.h. This plays the same role, using nanobind instead of pybind11. * splits CollectDiagnosticsToStringScope out of PybindAdaptors.h and into a new header mlir/include/mlir/Bindings/Python/Diagnostics.h, since it is code that is no way related to pybind11 or for that matter, Python. * changed the standalone Python extension example to have both pybind11 and nanobind variants. * changed mlir/python/mlir/dialects/python_test.py to have both pybind11 and nanobind variants. Notes: * A slightly unfortunate thing that I needed to do in the CMake integration was to use FindPython in addition to FindPython3, since nanobind's CMake integration expects the Python_ names for variables. Perhaps there's a better way to do this.
1 parent 866755f commit adbc118

File tree

26 files changed

+1155
-82
lines changed

26 files changed

+1155
-82
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,11 @@ endfunction()
114114
# EMBED_CAPI_LINK_LIBS: Dependent CAPI libraries that this extension depends
115115
# on. These will be collected for all extensions and put into an
116116
# aggregate dylib that is linked against.
117+
# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind.
117118
function(declare_mlir_python_extension name)
118119
cmake_parse_arguments(ARG
119120
""
120-
"ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
121+
"ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY"
121122
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
122123
${ARGN})
123124

@@ -126,15 +127,20 @@ function(declare_mlir_python_extension name)
126127
endif()
127128
set(_install_destination "src/python/${name}")
128129

130+
if(NOT ARG_PYTHON_BINDINGS_LIBRARY)
131+
set(ARG_PYTHON_BINDINGS_LIBRARY "pybind11")
132+
endif()
133+
129134
add_library(${name} INTERFACE)
130135
set_target_properties(${name} PROPERTIES
131136
# Yes: Leading-lowercase property names are load bearing and the recommended
132137
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
133-
EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
138+
EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY"
134139
mlir_python_SOURCES_TYPE extension
135140
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
136141
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
137142
mlir_python_DEPENDS ""
143+
mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}"
138144
)
139145

140146
# Set the interface source and link_libs properties of the target
@@ -223,12 +229,14 @@ function(add_mlir_python_modules name)
223229
elseif(_source_type STREQUAL "extension")
224230
# Native CPP extension.
225231
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
232+
get_target_property(_bindings_library ${sources_target} mlir_python_BINDINGS_LIBRARY)
226233
# Transform relative source to based on root dir.
227234
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
228235
add_mlir_python_extension(${_extension_target} "${_module_name}"
229236
INSTALL_COMPONENT ${modules_target}
230237
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
231238
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
239+
PYTHON_BINDINGS_LIBRARY ${_bindings_library}
232240
LINK_LIBS PRIVATE
233241
${sources_target}
234242
${ARG_COMMON_CAPI_LINK_LIBS}
@@ -634,7 +642,7 @@ endfunction()
634642
function(add_mlir_python_extension libname extname)
635643
cmake_parse_arguments(ARG
636644
""
637-
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
645+
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;PYTHON_BINDINGS_LIBRARY"
638646
"SOURCES;LINK_LIBS"
639647
${ARGN})
640648
if(ARG_UNPARSED_ARGUMENTS)
@@ -644,9 +652,15 @@ function(add_mlir_python_extension libname extname)
644652
# The actual extension library produces a shared-object or DLL and has
645653
# sources that must be compiled in accordance with pybind11 needs (RTTI and
646654
# exceptions).
647-
pybind11_add_module(${libname}
648-
${ARG_SOURCES}
649-
)
655+
if(NOT DEFINED ARG_PYTHON_BINDINGS_LIBRARY OR ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "pybind11")
656+
pybind11_add_module(${libname}
657+
${ARG_SOURCES}
658+
)
659+
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
660+
nanobind_add_module(${libname}
661+
${ARG_SOURCES}
662+
)
663+
endif()
650664

651665
# The extension itself must be compiled with RTTI and exceptions enabled.
652666
# Also, some warning classes triggered by pybind11 are disabled.

mlir/cmake/modules/MLIRDetectPythonEnv.cmake

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,21 @@ macro(mlir_configure_python_dev_packages)
2020

2121
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
2222
COMPONENTS Interpreter ${_python_development_component} REQUIRED)
23+
24+
# It's a little silly to detect Python a second time, but nanobind's cmake
25+
# code looks for Python_ not Python3_.
26+
find_package(Python ${LLVM_MINIMUM_PYTHON_VERSION}
27+
COMPONENTS Interpreter ${_python_development_component} REQUIRED)
2328
unset(_python_development_component)
2429
message(STATUS "Found python include dirs: ${Python3_INCLUDE_DIRS}")
2530
message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
2631
message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}")
2732
mlir_detect_pybind11_install()
2833
find_package(pybind11 2.10 CONFIG REQUIRED)
2934
message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
35+
mlir_detect_nanobind_install()
36+
find_package(nanobind 2.2 CONFIG REQUIRED)
37+
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
3038
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
3139
"suffix = '${PYTHON_MODULE_SUFFIX}', "
3240
"extension = '${PYTHON_MODULE_EXTENSION}")
@@ -56,3 +64,29 @@ function(mlir_detect_pybind11_install)
5664
set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
5765
endif()
5866
endfunction()
67+
68+
69+
# Detects a nanobind package installed in the current python environment
70+
# and sets variables to allow it to be found. This allows nanobind to be
71+
# installed via pip, which typically yields a much more recent version than
72+
# the OS install, which will be available otherwise.
73+
function(mlir_detect_nanobind_install)
74+
if(nanobind_DIR)
75+
message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)")
76+
else()
77+
message(STATUS "Checking for nanobind in python path...")
78+
execute_process(
79+
COMMAND "${Python3_EXECUTABLE}"
80+
-c "import nanobind;print(nanobind.cmake_dir(), end='')"
81+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
82+
RESULT_VARIABLE STATUS
83+
OUTPUT_VARIABLE PACKAGE_DIR
84+
ERROR_QUIET)
85+
if(NOT STATUS EQUAL "0")
86+
message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
87+
return()
88+
endif()
89+
message(STATUS "found (${PACKAGE_DIR})")
90+
set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
91+
endif()
92+
endfunction()

mlir/docs/Bindings/Python.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,12 +1138,14 @@ attributes and types must connect to the relevant C APIs for building and
11381138
inspection, which must be provided first. Bindings for `Attribute` and `Type`
11391139
subclasses can be defined using
11401140
[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)
1141-
utilities that mimic pybind11 API for defining functions and properties. These
1142-
bindings are to be included in a separate pybind11 module. The utilities also
1143-
provide automatic casting between C API handles `MlirAttribute` and `MlirType`
1144-
and their Python counterparts so that the C API handles can be used directly in
1145-
binding implementations. The methods and properties provided by the bindings
1146-
should follow the principles discussed above.
1141+
or
1142+
[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
1143+
utilities that mimic pybind11/nanobind API for defining functions and
1144+
properties. These bindings are to be included in a separate module. The
1145+
utilities also provide automatic casting between C API handles `MlirAttribute`
1146+
and `MlirType` and their Python counterparts so that the C API handles can be
1147+
used directly in binding implementations. The methods and properties provided by
1148+
the bindings should follow the principles discussed above.
11471149

11481150
The attribute and type bindings for a dialect can be located in
11491151
`lib/Bindings/Python/Dialect<Name>.cpp` and should be compiled into a separate
@@ -1179,7 +1181,9 @@ make the passes available along with the dialect.
11791181
Dialect functionality other than IR objects or passes, such as helper functions,
11801182
can be exposed to Python similarly to attributes and types. C API is expected to
11811183
exist for this functionality, which can then be wrapped using pybind11 and
1182-
`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`
1184+
`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`,
1185+
or nanobind and
1186+
`[include/mlir/Bindings/Python/NanobindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)`
11831187
utilities to connect to the rest of Python API. The bindings can be located in a
1184-
separate pybind11 module or in the same module as attributes and types, and
1188+
separate module or in the same module as attributes and types, and
11851189
loaded along with the dialect.

mlir/examples/standalone/python/CMakeLists.txt

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,32 @@ declare_mlir_dialect_python_bindings(
1717
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir_standalone"
1818
TD_FILE dialects/StandaloneOps.td
1919
SOURCES
20-
dialects/standalone.py
20+
dialects/standalone_pybind11.py
21+
dialects/standalone_nanobind.py
2122
DIALECT_NAME standalone)
2223

23-
declare_mlir_python_extension(StandalonePythonSources.Extension
24-
MODULE_NAME _standaloneDialects
24+
25+
declare_mlir_python_extension(StandalonePythonSources.Pybind11Extension
26+
MODULE_NAME _standaloneDialectsPybind11
27+
ADD_TO_PARENT StandalonePythonSources
28+
SOURCES
29+
StandaloneExtensionPybind11.cpp
30+
EMBED_CAPI_LINK_LIBS
31+
StandaloneCAPI
32+
PYTHON_BINDINGS_LIBRARY pybind11
33+
)
34+
35+
declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
36+
MODULE_NAME _standaloneDialectsNanobind
2537
ADD_TO_PARENT StandalonePythonSources
2638
SOURCES
27-
StandaloneExtension.cpp
39+
StandaloneExtensionNanobind.cpp
2840
EMBED_CAPI_LINK_LIBS
2941
StandaloneCAPI
42+
PYTHON_BINDINGS_LIBRARY nanobind
3043
)
3144

45+
3246
################################################################################
3347
# Common CAPI
3448
################################################################################
@@ -62,3 +76,4 @@ add_mlir_python_modules(StandalonePythonModules
6276
COMMON_CAPI_LINK_LIBS
6377
StandalonePythonCAPI
6478
)
79+
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- StandaloneExtension.cpp - Extension module -------------------------===//
2+
//
3+
// This is the nanobind version of the example module. There is also a pybind11
4+
// example in StandaloneExtensionPybind11.cpp.
5+
//
6+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7+
// See https://llvm.org/LICENSE.txt for license information.
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#include <nanobind/nanobind.h>
13+
14+
#include "Standalone-c/Dialects.h"
15+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
17+
namespace nb = nanobind;
18+
19+
NB_MODULE(_standaloneDialectsNanobind, m) {
20+
//===--------------------------------------------------------------------===//
21+
// standalone dialect
22+
//===--------------------------------------------------------------------===//
23+
auto standaloneM = m.def_submodule("standalone");
24+
25+
standaloneM.def(
26+
"register_dialect",
27+
[](MlirContext context, bool load) {
28+
MlirDialectHandle handle = mlirGetDialectHandle__standalone__();
29+
mlirDialectHandleRegisterDialect(handle, context);
30+
if (load) {
31+
mlirDialectHandleLoadDialect(handle, context);
32+
}
33+
},
34+
nb::arg("context").none() = nb::none(), nb::arg("load") = true);
35+
}

mlir/examples/standalone/python/StandaloneExtension.cpp renamed to mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
//===- StandaloneExtension.cpp - Extension module -------------------------===//
1+
//===- StandaloneExtensionPybind11.cpp - Extension module -----------------===//
2+
//
3+
// This is the pybind11 version of the example module. There is also a nanobind
4+
// example in StandaloneExtensionNanobind.cpp.
25
//
36
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,7 +14,7 @@
1114

1215
using namespace mlir::python::adaptors;
1316

14-
PYBIND11_MODULE(_standaloneDialects, m) {
17+
PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
1518
//===--------------------------------------------------------------------===//
1619
// standalone dialect
1720
//===--------------------------------------------------------------------===//

mlir/examples/standalone/python/mlir_standalone/dialects/standalone.py renamed to mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._standalone_ops_gen import *
6-
from .._mlir_libs._standaloneDialects.standalone import *
6+
from .._mlir_libs._standaloneDialectsNanobind.standalone import *
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._standalone_ops_gen import *
6+
from .._mlir_libs._standaloneDialectsPybind11.standalone import *

mlir/examples/standalone/test/python/smoketest.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
1-
# RUN: %python %s | FileCheck %s
1+
# RUN: %python %s pybind11 | FileCheck %s
2+
# RUN: %python %s nanobind | FileCheck %s
23

4+
import sys
35
from mlir_standalone.ir import *
4-
from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
6+
from mlir_standalone.dialects import builtin as builtin_d
7+
8+
if sys.argv[1] == "pybind11":
9+
from mlir_standalone.dialects import standalone_pybind11 as standalone_d
10+
elif sys.argv[1] == "nanobind":
11+
from mlir_standalone.dialects import standalone_nanobind as standalone_d
12+
else:
13+
raise ValueError("Expected either pybind11 or nanobind as arguments")
14+
515

616
with Context():
717
standalone_d.register_dialect()
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
10+
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
11+
12+
#include <cassert>
13+
#include <string>
14+
15+
#include "mlir-c/Diagnostics.h"
16+
#include "mlir-c/IR.h"
17+
#include "llvm/ADT/StringRef.h"
18+
19+
namespace mlir {
20+
namespace python {
21+
22+
/// RAII scope intercepting all diagnostics into a string. The message must be
23+
/// checked before this goes out of scope.
24+
class CollectDiagnosticsToStringScope {
25+
public:
26+
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27+
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28+
/*deleteUserData=*/nullptr);
29+
}
30+
~CollectDiagnosticsToStringScope() {
31+
assert(errorMessage.empty() && "unchecked error message");
32+
mlirContextDetachDiagnosticHandler(context, handlerID);
33+
}
34+
35+
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
36+
37+
private:
38+
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
39+
auto printer = +[](MlirStringRef message, void *data) {
40+
*static_cast<std::string *>(data) +=
41+
llvm::StringRef(message.data, message.length);
42+
};
43+
MlirLocation loc = mlirDiagnosticGetLocation(diag);
44+
*static_cast<std::string *>(data) += "at ";
45+
mlirLocationPrint(loc, printer, data);
46+
*static_cast<std::string *>(data) += ": ";
47+
mlirDiagnosticPrint(diag, printer, data);
48+
return mlirLogicalResultSuccess();
49+
}
50+
51+
MlirContext context;
52+
MlirDiagnosticHandlerID handlerID;
53+
std::string errorMessage = "";
54+
};
55+
56+
} // namespace python
57+
} // namespace mlir
58+
59+
#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H

0 commit comments

Comments
 (0)