Skip to content

Commit b56d1ec

Browse files
authored
[mlir python] Port Python core code to nanobind. (#120473)
Relands #118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8. Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing `pybind11::` to `nanobind::`. Notes: * this PR needs Nanobind 2.4.0, because it needs a bug fix (wjakob/nanobind#806) that landed in that release. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in `PybindAdapters.h`. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (`nb::bytes`) from strings (e.g., `std::string`). This required nb::bytes overloads in a few places.
1 parent 89b34ec commit b56d1ec

File tree

23 files changed

+1898
-1583
lines changed

23 files changed

+1898
-1583
lines changed

mlir/cmake/modules/MLIRDetectPythonEnv.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
3939
"extension = '${PYTHON_MODULE_EXTENSION}")
4040

4141
mlir_detect_nanobind_install()
42-
find_package(nanobind 2.2 CONFIG REQUIRED)
42+
find_package(nanobind 2.4 CONFIG REQUIRED)
4343
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
4444
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
4545
"suffix = '${PYTHON_MODULE_SUFFIX}', "

mlir/include/mlir/Bindings/Python/IRTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
1010
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
1111

12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

1414
namespace mlir {
1515

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
6464
/// Casts object <-> MlirAffineMap.
6565
template <>
6666
struct type_caster<MlirAffineMap> {
67-
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
67+
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
6868
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
6969
nanobind::object capsule = mlirApiObjectToCapsule(src);
7070
value = mlirPythonCapsuleToAffineMap(capsule.ptr());
@@ -87,7 +87,7 @@ struct type_caster<MlirAffineMap> {
8787
/// Casts object <-> MlirAttribute.
8888
template <>
8989
struct type_caster<MlirAttribute> {
90-
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
90+
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
9191
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
9292
nanobind::object capsule = mlirApiObjectToCapsule(src);
9393
value = mlirPythonCapsuleToAttribute(capsule.ptr());
@@ -108,7 +108,7 @@ struct type_caster<MlirAttribute> {
108108
/// Casts object -> MlirBlock.
109109
template <>
110110
struct type_caster<MlirBlock> {
111-
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
111+
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
112112
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
113113
nanobind::object capsule = mlirApiObjectToCapsule(src);
114114
value = mlirPythonCapsuleToBlock(capsule.ptr());
@@ -119,7 +119,7 @@ struct type_caster<MlirBlock> {
119119
/// Casts object -> MlirContext.
120120
template <>
121121
struct type_caster<MlirContext> {
122-
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
122+
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
123123
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
124124
if (src.is_none()) {
125125
// Gets the current thread-bound context.
@@ -139,7 +139,7 @@ struct type_caster<MlirContext> {
139139
/// Casts object <-> MlirDialectRegistry.
140140
template <>
141141
struct type_caster<MlirDialectRegistry> {
142-
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
142+
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
143143
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
144144
nanobind::object capsule = mlirApiObjectToCapsule(src);
145145
value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
@@ -159,7 +159,7 @@ struct type_caster<MlirDialectRegistry> {
159159
/// Casts object <-> MlirLocation.
160160
template <>
161161
struct type_caster<MlirLocation> {
162-
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
162+
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
163163
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
164164
if (src.is_none()) {
165165
// Gets the current thread-bound context.
@@ -185,7 +185,7 @@ struct type_caster<MlirLocation> {
185185
/// Casts object <-> MlirModule.
186186
template <>
187187
struct type_caster<MlirModule> {
188-
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
188+
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
189189
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
190190
nanobind::object capsule = mlirApiObjectToCapsule(src);
191191
value = mlirPythonCapsuleToModule(capsule.ptr());
@@ -206,7 +206,7 @@ struct type_caster<MlirModule> {
206206
template <>
207207
struct type_caster<MlirFrozenRewritePatternSet> {
208208
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
209-
const_name("MlirFrozenRewritePatternSet"));
209+
const_name("MlirFrozenRewritePatternSet"))
210210
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
211211
nanobind::object capsule = mlirApiObjectToCapsule(src);
212212
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
@@ -225,7 +225,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
225225
/// Casts object <-> MlirOperation.
226226
template <>
227227
struct type_caster<MlirOperation> {
228-
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
228+
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
229229
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
230230
nanobind::object capsule = mlirApiObjectToCapsule(src);
231231
value = mlirPythonCapsuleToOperation(capsule.ptr());
@@ -247,7 +247,7 @@ struct type_caster<MlirOperation> {
247247
/// Casts object <-> MlirValue.
248248
template <>
249249
struct type_caster<MlirValue> {
250-
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
250+
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
251251
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
252252
nanobind::object capsule = mlirApiObjectToCapsule(src);
253253
value = mlirPythonCapsuleToValue(capsule.ptr());
@@ -270,7 +270,7 @@ struct type_caster<MlirValue> {
270270
/// Casts object -> MlirPassManager.
271271
template <>
272272
struct type_caster<MlirPassManager> {
273-
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
273+
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
274274
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
275275
nanobind::object capsule = mlirApiObjectToCapsule(src);
276276
value = mlirPythonCapsuleToPassManager(capsule.ptr());
@@ -281,7 +281,7 @@ struct type_caster<MlirPassManager> {
281281
/// Casts object <-> MlirTypeID.
282282
template <>
283283
struct type_caster<MlirTypeID> {
284-
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
284+
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
285285
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
286286
nanobind::object capsule = mlirApiObjectToCapsule(src);
287287
value = mlirPythonCapsuleToTypeID(capsule.ptr());
@@ -303,7 +303,7 @@ struct type_caster<MlirTypeID> {
303303
/// Casts object <-> MlirType.
304304
template <>
305305
struct type_caster<MlirType> {
306-
NB_TYPE_CASTER(MlirType, const_name("MlirType"));
306+
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
307307
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
308308
nanobind::object capsule = mlirApiObjectToCapsule(src);
309309
value = mlirPythonCapsuleToType(capsule.ptr());

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,8 @@ class pure_subclass {
374374
static_assert(!std::is_member_function_pointer<Func>::value,
375375
"def_staticmethod(...) called with a non-static member "
376376
"function pointer");
377-
py::cpp_function cf(
378-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
379-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
377+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
378+
py::scope(thisClass), extra...);
380379
thisClass.attr(cf.name()) = py::staticmethod(cf);
381380
return *this;
382381
}
@@ -387,9 +386,8 @@ class pure_subclass {
387386
static_assert(!std::is_member_function_pointer<Func>::value,
388387
"def_classmethod(...) called with a non-static member "
389388
"function pointer");
390-
py::cpp_function cf(
391-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
392-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
389+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
390+
py::scope(thisClass), extra...);
393391
thisClass.attr(cf.name()) =
394392
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
395393
return *this;

mlir/lib/Bindings/Python/Globals.h

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@
99
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

12-
#include "PybindUtils.h"
12+
#include <optional>
13+
#include <string>
14+
#include <vector>
1315

16+
#include "NanobindUtils.h"
1417
#include "mlir-c/IR.h"
1518
#include "mlir/CAPI/Support.h"
1619
#include "llvm/ADT/DenseMap.h"
1720
#include "llvm/ADT/StringRef.h"
1821
#include "llvm/ADT/StringSet.h"
1922

20-
#include <optional>
21-
#include <string>
22-
#include <vector>
23-
2423
namespace mlir {
2524
namespace python {
2625

@@ -57,71 +56,71 @@ class PyGlobals {
5756
/// Raises an exception if the mapping already exists and replace == false.
5857
/// This is intended to be called by implementation code.
5958
void registerAttributeBuilder(const std::string &attributeKind,
60-
pybind11::function pyFunc,
59+
nanobind::callable pyFunc,
6160
bool replace = false);
6261

6362
/// Adds a user-friendly type caster. Raises an exception if the mapping
6463
/// already exists and replace == false. This is intended to be called by
6564
/// implementation code.
66-
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
65+
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
6766
bool replace = false);
6867

6968
/// Adds a user-friendly value caster. Raises an exception if the mapping
7069
/// already exists and replace == false. This is intended to be called by
7170
/// implementation code.
7271
void registerValueCaster(MlirTypeID mlirTypeID,
73-
pybind11::function valueCaster,
72+
nanobind::callable valueCaster,
7473
bool replace = false);
7574

7675
/// Adds a concrete implementation dialect class.
7776
/// Raises an exception if the mapping already exists.
7877
/// This is intended to be called by implementation code.
7978
void registerDialectImpl(const std::string &dialectNamespace,
80-
pybind11::object pyClass);
79+
nanobind::object pyClass);
8180

8281
/// Adds a concrete implementation operation class.
8382
/// Raises an exception if the mapping already exists and replace == false.
8483
/// This is intended to be called by implementation code.
8584
void registerOperationImpl(const std::string &operationName,
86-
pybind11::object pyClass, bool replace = false);
85+
nanobind::object pyClass, bool replace = false);
8786

8887
/// Returns the custom Attribute builder for Attribute kind.
89-
std::optional<pybind11::function>
88+
std::optional<nanobind::callable>
9089
lookupAttributeBuilder(const std::string &attributeKind);
9190

9291
/// Returns the custom type caster for MlirTypeID mlirTypeID.
93-
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
92+
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
9493
MlirDialect dialect);
9594

9695
/// Returns the custom value caster for MlirTypeID mlirTypeID.
97-
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
96+
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
9897
MlirDialect dialect);
9998

10099
/// Looks up a registered dialect class by namespace. Note that this may
101100
/// trigger loading of the defining module and can arbitrarily re-enter.
102-
std::optional<pybind11::object>
101+
std::optional<nanobind::object>
103102
lookupDialectClass(const std::string &dialectNamespace);
104103

105104
/// Looks up a registered operation class (deriving from OpView) by operation
106105
/// name. Note that this may trigger a load of the dialect, which can
107106
/// arbitrarily re-enter.
108-
std::optional<pybind11::object>
107+
std::optional<nanobind::object>
109108
lookupOperationClass(llvm::StringRef operationName);
110109

111110
private:
112111
static PyGlobals *instance;
113112
/// Module name prefixes to search under for dialect implementation modules.
114113
std::vector<std::string> dialectSearchPrefixes;
115114
/// Map of dialect namespace to external dialect class object.
116-
llvm::StringMap<pybind11::object> dialectClassMap;
115+
llvm::StringMap<nanobind::object> dialectClassMap;
117116
/// Map of full operation name to external operation class object.
118-
llvm::StringMap<pybind11::object> operationClassMap;
117+
llvm::StringMap<nanobind::object> operationClassMap;
119118
/// Map of attribute ODS name to custom builder.
120-
llvm::StringMap<pybind11::object> attributeBuilderMap;
119+
llvm::StringMap<nanobind::callable> attributeBuilderMap;
121120
/// Map of MlirTypeID to custom type caster.
122-
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
121+
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
123122
/// Map of MlirTypeID to custom value caster.
124-
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
123+
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
125124
/// Set of dialect namespaces that we have attempted to import implementation
126125
/// modules for.
127126
llvm::StringSet<> loadedDialectModules;

0 commit comments

Comments
 (0)