Skip to content

Commit 522a30d

Browse files
committed
[mlir python] Port Python core code to nanobind.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing pybind11:: to nanobind::. Notes: * this PR needs wjakob/nanobind#806 to land in nanobind first. Without that fix, importing the MLIR modules will fail. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in PybindAdapters.h. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (nb::bytes) from strings (e.g., std::string). This required nb::bytes overloads in a few places.
1 parent 5d8eabc commit 522a30d

File tree

18 files changed

+1838
-1576
lines changed

18 files changed

+1838
-1576
lines changed

mlir/include/mlir/Bindings/Python/IRTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
1010
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
1111

12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

1414
namespace mlir {
1515

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ static py::object mlirApiObjectToCapsule(py::handle apiObject) {
6464
// ownership is unclear.
6565

6666
/// Casts object <-> MlirAffineMap.
67-
template <>
68-
struct type_caster<MlirAffineMap> {
67+
template <> struct type_caster<MlirAffineMap> {
6968
PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
7069
bool load(handle src, bool) {
7170
py::object capsule = mlirApiObjectToCapsule(src);
@@ -86,8 +85,7 @@ struct type_caster<MlirAffineMap> {
8685
};
8786

8887
/// Casts object <-> MlirAttribute.
89-
template <>
90-
struct type_caster<MlirAttribute> {
88+
template <> struct type_caster<MlirAttribute> {
9189
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
9290
bool load(handle src, bool) {
9391
py::object capsule = mlirApiObjectToCapsule(src);
@@ -106,8 +104,7 @@ struct type_caster<MlirAttribute> {
106104
};
107105

108106
/// Casts object -> MlirBlock.
109-
template <>
110-
struct type_caster<MlirBlock> {
107+
template <> struct type_caster<MlirBlock> {
111108
PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"));
112109
bool load(handle src, bool) {
113110
py::object capsule = mlirApiObjectToCapsule(src);
@@ -117,8 +114,7 @@ struct type_caster<MlirBlock> {
117114
};
118115

119116
/// Casts object -> MlirContext.
120-
template <>
121-
struct type_caster<MlirContext> {
117+
template <> struct type_caster<MlirContext> {
122118
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
123119
bool load(handle src, bool) {
124120
if (src.is_none()) {
@@ -137,8 +133,7 @@ struct type_caster<MlirContext> {
137133
};
138134

139135
/// Casts object <-> MlirDialectRegistry.
140-
template <>
141-
struct type_caster<MlirDialectRegistry> {
136+
template <> struct type_caster<MlirDialectRegistry> {
142137
PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"));
143138
bool load(handle src, bool) {
144139
py::object capsule = mlirApiObjectToCapsule(src);
@@ -156,8 +151,7 @@ struct type_caster<MlirDialectRegistry> {
156151
};
157152

158153
/// Casts object <-> MlirLocation.
159-
template <>
160-
struct type_caster<MlirLocation> {
154+
template <> struct type_caster<MlirLocation> {
161155
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
162156
bool load(handle src, bool) {
163157
if (src.is_none()) {
@@ -181,8 +175,7 @@ struct type_caster<MlirLocation> {
181175
};
182176

183177
/// Casts object <-> MlirModule.
184-
template <>
185-
struct type_caster<MlirModule> {
178+
template <> struct type_caster<MlirModule> {
186179
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
187180
bool load(handle src, bool) {
188181
py::object capsule = mlirApiObjectToCapsule(src);
@@ -200,8 +193,7 @@ struct type_caster<MlirModule> {
200193
};
201194

202195
/// Casts object <-> MlirFrozenRewritePatternSet.
203-
template <>
204-
struct type_caster<MlirFrozenRewritePatternSet> {
196+
template <> struct type_caster<MlirFrozenRewritePatternSet> {
205197
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
206198
_("MlirFrozenRewritePatternSet"));
207199
bool load(handle src, bool) {
@@ -221,8 +213,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
221213
};
222214

223215
/// Casts object <-> MlirOperation.
224-
template <>
225-
struct type_caster<MlirOperation> {
216+
template <> struct type_caster<MlirOperation> {
226217
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
227218
bool load(handle src, bool) {
228219
py::object capsule = mlirApiObjectToCapsule(src);
@@ -242,8 +233,7 @@ struct type_caster<MlirOperation> {
242233
};
243234

244235
/// Casts object <-> MlirValue.
245-
template <>
246-
struct type_caster<MlirValue> {
236+
template <> struct type_caster<MlirValue> {
247237
PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"));
248238
bool load(handle src, bool) {
249239
py::object capsule = mlirApiObjectToCapsule(src);
@@ -264,8 +254,7 @@ struct type_caster<MlirValue> {
264254
};
265255

266256
/// Casts object -> MlirPassManager.
267-
template <>
268-
struct type_caster<MlirPassManager> {
257+
template <> struct type_caster<MlirPassManager> {
269258
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
270259
bool load(handle src, bool) {
271260
py::object capsule = mlirApiObjectToCapsule(src);
@@ -275,8 +264,7 @@ struct type_caster<MlirPassManager> {
275264
};
276265

277266
/// Casts object <-> MlirTypeID.
278-
template <>
279-
struct type_caster<MlirTypeID> {
267+
template <> struct type_caster<MlirTypeID> {
280268
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
281269
bool load(handle src, bool) {
282270
py::object capsule = mlirApiObjectToCapsule(src);
@@ -296,8 +284,7 @@ struct type_caster<MlirTypeID> {
296284
};
297285

298286
/// Casts object <-> MlirType.
299-
template <>
300-
struct type_caster<MlirType> {
287+
template <> struct type_caster<MlirType> {
301288
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
302289
bool load(handle src, bool) {
303290
py::object capsule = mlirApiObjectToCapsule(src);
@@ -374,9 +361,8 @@ class pure_subclass {
374361
static_assert(!std::is_member_function_pointer<Func>::value,
375362
"def_staticmethod(...) called with a non-static member "
376363
"function pointer");
377-
py::cpp_function cf(
378-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
379-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
364+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
365+
py::scope(thisClass), extra...);
380366
thisClass.attr(cf.name()) = py::staticmethod(cf);
381367
return *this;
382368
}
@@ -387,9 +373,8 @@ class pure_subclass {
387373
static_assert(!std::is_member_function_pointer<Func>::value,
388374
"def_classmethod(...) called with a non-static member "
389375
"function pointer");
390-
py::cpp_function cf(
391-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
392-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
376+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
377+
py::scope(thisClass), extra...);
393378
thisClass.attr(cf.name()) =
394379
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
395380
return *this;

mlir/lib/Bindings/Python/Globals.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

12-
#include "PybindUtils.h"
12+
#include "NanobindUtils.h"
1313

1414
#include "mlir-c/IR.h"
1515
#include "mlir/CAPI/Support.h"
@@ -57,71 +57,71 @@ class PyGlobals {
5757
/// Raises an exception if the mapping already exists and replace == false.
5858
/// This is intended to be called by implementation code.
5959
void registerAttributeBuilder(const std::string &attributeKind,
60-
pybind11::function pyFunc,
60+
nanobind::callable pyFunc,
6161
bool replace = false);
6262

6363
/// Adds a user-friendly type caster. Raises an exception if the mapping
6464
/// already exists and replace == false. This is intended to be called by
6565
/// implementation code.
66-
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
66+
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
6767
bool replace = false);
6868

6969
/// Adds a user-friendly value caster. Raises an exception if the mapping
7070
/// already exists and replace == false. This is intended to be called by
7171
/// implementation code.
7272
void registerValueCaster(MlirTypeID mlirTypeID,
73-
pybind11::function valueCaster,
73+
nanobind::callable valueCaster,
7474
bool replace = false);
7575

7676
/// Adds a concrete implementation dialect class.
7777
/// Raises an exception if the mapping already exists.
7878
/// This is intended to be called by implementation code.
7979
void registerDialectImpl(const std::string &dialectNamespace,
80-
pybind11::object pyClass);
80+
nanobind::object pyClass);
8181

8282
/// Adds a concrete implementation operation class.
8383
/// Raises an exception if the mapping already exists and replace == false.
8484
/// This is intended to be called by implementation code.
8585
void registerOperationImpl(const std::string &operationName,
86-
pybind11::object pyClass, bool replace = false);
86+
nanobind::object pyClass, bool replace = false);
8787

8888
/// Returns the custom Attribute builder for Attribute kind.
89-
std::optional<pybind11::function>
89+
std::optional<nanobind::callable>
9090
lookupAttributeBuilder(const std::string &attributeKind);
9191

9292
/// Returns the custom type caster for MlirTypeID mlirTypeID.
93-
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
93+
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
9494
MlirDialect dialect);
9595

9696
/// Returns the custom value caster for MlirTypeID mlirTypeID.
97-
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
97+
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
9898
MlirDialect dialect);
9999

100100
/// Looks up a registered dialect class by namespace. Note that this may
101101
/// trigger loading of the defining module and can arbitrarily re-enter.
102-
std::optional<pybind11::object>
102+
std::optional<nanobind::object>
103103
lookupDialectClass(const std::string &dialectNamespace);
104104

105105
/// Looks up a registered operation class (deriving from OpView) by operation
106106
/// name. Note that this may trigger a load of the dialect, which can
107107
/// arbitrarily re-enter.
108-
std::optional<pybind11::object>
108+
std::optional<nanobind::object>
109109
lookupOperationClass(llvm::StringRef operationName);
110110

111111
private:
112112
static PyGlobals *instance;
113113
/// Module name prefixes to search under for dialect implementation modules.
114114
std::vector<std::string> dialectSearchPrefixes;
115115
/// Map of dialect namespace to external dialect class object.
116-
llvm::StringMap<pybind11::object> dialectClassMap;
116+
llvm::StringMap<nanobind::object> dialectClassMap;
117117
/// Map of full operation name to external operation class object.
118-
llvm::StringMap<pybind11::object> operationClassMap;
118+
llvm::StringMap<nanobind::object> operationClassMap;
119119
/// Map of attribute ODS name to custom builder.
120-
llvm::StringMap<pybind11::object> attributeBuilderMap;
120+
llvm::StringMap<nanobind::callable> attributeBuilderMap;
121121
/// Map of MlirTypeID to custom type caster.
122-
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
122+
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
123123
/// Map of MlirTypeID to custom value caster.
124-
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
124+
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
125125
/// Set of dialect namespaces that we have attempted to import implementation
126126
/// modules for.
127127
llvm::StringSet<> loadedDialectModules;

0 commit comments

Comments
 (0)