Skip to content

Commit 7942d33

Browse files
committed
[mlir python] Port Python core code to nanobind.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing pybind11:: to nanobind::. Notes: * this PR needs wjakob/nanobind#806 to land in nanobind first. Without that fix, importing the MLIR modules will fail. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in PybindAdapters.h. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (nb::bytes) from strings (e.g., std::string). This required nb::bytes overloads in a few places.
1 parent 5d8eabc commit 7942d33

File tree

20 files changed

+2165
-1916
lines changed

20 files changed

+2165
-1916
lines changed

mlir/include/mlir/Bindings/Python/Diagnostics.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
#include <cassert>
1313
#include <string>
1414

15+
#include "llvm/ADT/StringRef.h"
1516
#include "mlir-c/Diagnostics.h"
1617
#include "mlir-c/IR.h"
17-
#include "llvm/ADT/StringRef.h"
1818

1919
namespace mlir {
2020
namespace python {
2121

2222
/// RAII scope intercepting all diagnostics into a string. The message must be
2323
/// checked before this goes out of scope.
2424
class CollectDiagnosticsToStringScope {
25-
public:
25+
public:
2626
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
2727
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
2828
/*deleteUserData=*/nullptr);
@@ -34,7 +34,7 @@ class CollectDiagnosticsToStringScope {
3434

3535
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
3636

37-
private:
37+
private:
3838
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
3939
auto printer = +[](MlirStringRef message, void *data) {
4040
*static_cast<std::string *>(data) +=
@@ -53,7 +53,7 @@ class CollectDiagnosticsToStringScope {
5353
std::string errorMessage = "";
5454
};
5555

56-
} // namespace python
57-
} // namespace mlir
56+
} // namespace python
57+
} // namespace mlir
5858

59-
#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
59+
#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H

mlir/include/mlir/Bindings/Python/IRTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
1010
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
1111

12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

1414
namespace mlir {
1515

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424

2525
#include <cstdint>
2626

27+
#include "llvm/ADT/Twine.h"
2728
#include "mlir-c/Bindings/Python/Interop.h"
2829
#include "mlir-c/Diagnostics.h"
2930
#include "mlir-c/IR.h"
30-
#include "llvm/ADT/Twine.h"
3131

3232
// Raw CAPI type casters need to be declared before use, so always include them
3333
// first.
@@ -233,8 +233,7 @@ struct type_caster<MlirOperation> {
233233
}
234234
static handle from_cpp(MlirOperation v, rv_policy,
235235
cleanup_list *cleanup) noexcept {
236-
if (v.ptr == nullptr)
237-
return nanobind::none();
236+
if (v.ptr == nullptr) return nanobind::none();
238237
nanobind::object capsule =
239238
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
240239
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -255,8 +254,7 @@ struct type_caster<MlirValue> {
255254
}
256255
static handle from_cpp(MlirValue v, rv_policy,
257256
cleanup_list *cleanup) noexcept {
258-
if (v.ptr == nullptr)
259-
return nanobind::none();
257+
if (v.ptr == nullptr) return nanobind::none();
260258
nanobind::object capsule =
261259
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
262260
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -289,8 +287,7 @@ struct type_caster<MlirTypeID> {
289287
}
290288
static handle from_cpp(MlirTypeID v, rv_policy,
291289
cleanup_list *cleanup) noexcept {
292-
if (v.ptr == nullptr)
293-
return nanobind::none();
290+
if (v.ptr == nullptr) return nanobind::none();
294291
nanobind::object capsule =
295292
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
296293
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -321,8 +318,8 @@ struct type_caster<MlirType> {
321318
}
322319
};
323320

324-
} // namespace detail
325-
} // namespace nanobind
321+
} // namespace detail
322+
} // namespace nanobind
326323

327324
namespace mlir {
328325
namespace python {
@@ -340,7 +337,7 @@ namespace nanobind_adaptors {
340337
/// (plus a fair amount of extra curricular poking)
341338
/// TODO: If this proves useful, see about including it in nanobind.
342339
class pure_subclass {
343-
public:
340+
public:
344341
pure_subclass(nanobind::handle scope, const char *derivedClassName,
345342
const nanobind::object &superClass) {
346343
nanobind::object pyType =
@@ -382,7 +379,7 @@ class pure_subclass {
382379
"function pointer");
383380
nanobind::object cf = nanobind::cpp_function(
384381
std::forward<Func>(f),
385-
nanobind::name(name), // nanobind::scope(thisClass),
382+
nanobind::name(name), // nanobind::scope(thisClass),
386383
extra...);
387384
thisClass.attr(name) = cf;
388385
return *this;
@@ -396,7 +393,7 @@ class pure_subclass {
396393
"function pointer");
397394
nanobind::object cf = nanobind::cpp_function(
398395
std::forward<Func>(f),
399-
nanobind::name(name), // nanobind::scope(thisClass),
396+
nanobind::name(name), // nanobind::scope(thisClass),
400397
extra...);
401398
thisClass.attr(name) =
402399
nanobind::borrow<nanobind::object>(PyClassMethod_New(cf.ptr()));
@@ -405,15 +402,15 @@ class pure_subclass {
405402

406403
nanobind::object get_class() const { return thisClass; }
407404

408-
protected:
405+
protected:
409406
nanobind::object superClass;
410407
nanobind::object thisClass;
411408
};
412409

413410
/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
414411
/// constructor and type checking methods.
415412
class mlir_attribute_subclass : public pure_subclass {
416-
public:
413+
public:
417414
using IsAFunctionTy = bool (*)(MlirAttribute);
418415
using GetTypeIDFunctionTy = MlirTypeID (*)();
419416

@@ -445,7 +442,7 @@ class mlir_attribute_subclass : public pure_subclass {
445442
// have no additional members, we can just return the instance thus created
446443
// without amending it.
447444
std::string captureTypeName(
448-
typeClassName); // As string in case if typeClassName is not static.
445+
typeClassName); // As string in case if typeClassName is not static.
449446
nanobind::object newCf = nanobind::cpp_function(
450447
[superCls, isaFunction, captureTypeName](
451448
nanobind::object cls, nanobind::object otherAttribute) {
@@ -491,7 +488,7 @@ class mlir_attribute_subclass : public pure_subclass {
491488
/// Creates a custom subclass of mlir.ir.Type, implementing a casting
492489
/// constructor and type checking methods.
493490
class mlir_type_subclass : public pure_subclass {
494-
public:
491+
public:
495492
using IsAFunctionTy = bool (*)(MlirType);
496493
using GetTypeIDFunctionTy = MlirTypeID (*)();
497494

@@ -523,7 +520,7 @@ class mlir_type_subclass : public pure_subclass {
523520
// have no additional members, we can just return the instance thus created
524521
// without amending it.
525522
std::string captureTypeName(
526-
typeClassName); // As string in case if typeClassName is not static.
523+
typeClassName); // As string in case if typeClassName is not static.
527524
nanobind::object newCf = nanobind::cpp_function(
528525
[superCls, isaFunction, captureTypeName](nanobind::object cls,
529526
nanobind::object otherType) {
@@ -573,7 +570,7 @@ class mlir_type_subclass : public pure_subclass {
573570
/// Creates a custom subclass of mlir.ir.Value, implementing a casting
574571
/// constructor and type checking methods.
575572
class mlir_value_subclass : public pure_subclass {
576-
public:
573+
public:
577574
using IsAFunctionTy = bool (*)(MlirValue);
578575

579576
/// Subclasses by looking up the super-class dynamically.
@@ -601,7 +598,7 @@ class mlir_value_subclass : public pure_subclass {
601598
// have no additional members, we can just return the instance thus created
602599
// without amending it.
603600
std::string captureValueName(
604-
valueClassName); // As string in case if valueClassName is not static.
601+
valueClassName); // As string in case if valueClassName is not static.
605602
nanobind::object newCf = nanobind::cpp_function(
606603
[superCls, isaFunction, captureValueName](nanobind::object cls,
607604
nanobind::object otherValue) {
@@ -629,12 +626,12 @@ class mlir_value_subclass : public pure_subclass {
629626
}
630627
};
631628

632-
} // namespace nanobind_adaptors
629+
} // namespace nanobind_adaptors
633630

634631
/// RAII scope intercepting all diagnostics into a string. The message must be
635632
/// checked before this goes out of scope.
636633
class CollectDiagnosticsToStringScope {
637-
public:
634+
public:
638635
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
639636
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
640637
/*deleteUserData=*/nullptr);
@@ -646,7 +643,7 @@ class CollectDiagnosticsToStringScope {
646643

647644
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
648645

649-
private:
646+
private:
650647
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
651648
auto printer = +[](MlirStringRef message, void *data) {
652649
*static_cast<std::string *>(data) +=
@@ -665,7 +662,7 @@ class CollectDiagnosticsToStringScope {
665662
std::string errorMessage = "";
666663
};
667664

668-
} // namespace python
669-
} // namespace mlir
665+
} // namespace python
666+
} // namespace mlir
670667

671-
#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
668+
#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,8 @@ class pure_subclass {
374374
static_assert(!std::is_member_function_pointer<Func>::value,
375375
"def_staticmethod(...) called with a non-static member "
376376
"function pointer");
377-
py::cpp_function cf(
378-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
379-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
377+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
378+
py::scope(thisClass), extra...);
380379
thisClass.attr(cf.name()) = py::staticmethod(cf);
381380
return *this;
382381
}
@@ -387,9 +386,8 @@ class pure_subclass {
387386
static_assert(!std::is_member_function_pointer<Func>::value,
388387
"def_classmethod(...) called with a non-static member "
389388
"function pointer");
390-
py::cpp_function cf(
391-
std::forward<Func>(f), py::name(name), py::scope(thisClass),
392-
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
389+
py::cpp_function cf(std::forward<Func>(f), py::name(name),
390+
py::scope(thisClass), extra...);
393391
thisClass.attr(cf.name()) =
394392
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
395393
return *this;

mlir/lib/Bindings/Python/Globals.h

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@
99
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

12-
#include "PybindUtils.h"
12+
#include <optional>
13+
#include <string>
14+
#include <vector>
1315

14-
#include "mlir-c/IR.h"
15-
#include "mlir/CAPI/Support.h"
16+
#include "NanobindUtils.h"
1617
#include "llvm/ADT/DenseMap.h"
1718
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/ADT/StringSet.h"
19-
20-
#include <optional>
21-
#include <string>
22-
#include <vector>
20+
#include "mlir-c/IR.h"
21+
#include "mlir/CAPI/Support.h"
2322

2423
namespace mlir {
2524
namespace python {
@@ -57,71 +56,71 @@ class PyGlobals {
5756
/// Raises an exception if the mapping already exists and replace == false.
5857
/// This is intended to be called by implementation code.
5958
void registerAttributeBuilder(const std::string &attributeKind,
60-
pybind11::function pyFunc,
59+
nanobind::callable pyFunc,
6160
bool replace = false);
6261

6362
/// Adds a user-friendly type caster. Raises an exception if the mapping
6463
/// already exists and replace == false. This is intended to be called by
6564
/// implementation code.
66-
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
65+
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
6766
bool replace = false);
6867

6968
/// Adds a user-friendly value caster. Raises an exception if the mapping
7069
/// already exists and replace == false. This is intended to be called by
7170
/// implementation code.
7271
void registerValueCaster(MlirTypeID mlirTypeID,
73-
pybind11::function valueCaster,
72+
nanobind::callable valueCaster,
7473
bool replace = false);
7574

7675
/// Adds a concrete implementation dialect class.
7776
/// Raises an exception if the mapping already exists.
7877
/// This is intended to be called by implementation code.
7978
void registerDialectImpl(const std::string &dialectNamespace,
80-
pybind11::object pyClass);
79+
nanobind::object pyClass);
8180

8281
/// Adds a concrete implementation operation class.
8382
/// Raises an exception if the mapping already exists and replace == false.
8483
/// This is intended to be called by implementation code.
8584
void registerOperationImpl(const std::string &operationName,
86-
pybind11::object pyClass, bool replace = false);
85+
nanobind::object pyClass, bool replace = false);
8786

8887
/// Returns the custom Attribute builder for Attribute kind.
89-
std::optional<pybind11::function>
90-
lookupAttributeBuilder(const std::string &attributeKind);
88+
std::optional<nanobind::callable> lookupAttributeBuilder(
89+
const std::string &attributeKind);
9190

9291
/// Returns the custom type caster for MlirTypeID mlirTypeID.
93-
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
92+
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
9493
MlirDialect dialect);
9594

9695
/// Returns the custom value caster for MlirTypeID mlirTypeID.
97-
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
96+
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
9897
MlirDialect dialect);
9998

10099
/// Looks up a registered dialect class by namespace. Note that this may
101100
/// trigger loading of the defining module and can arbitrarily re-enter.
102-
std::optional<pybind11::object>
103-
lookupDialectClass(const std::string &dialectNamespace);
101+
std::optional<nanobind::object> lookupDialectClass(
102+
const std::string &dialectNamespace);
104103

105104
/// Looks up a registered operation class (deriving from OpView) by operation
106105
/// name. Note that this may trigger a load of the dialect, which can
107106
/// arbitrarily re-enter.
108-
std::optional<pybind11::object>
109-
lookupOperationClass(llvm::StringRef operationName);
107+
std::optional<nanobind::object> lookupOperationClass(
108+
llvm::StringRef operationName);
110109

111-
private:
110+
private:
112111
static PyGlobals *instance;
113112
/// Module name prefixes to search under for dialect implementation modules.
114113
std::vector<std::string> dialectSearchPrefixes;
115114
/// Map of dialect namespace to external dialect class object.
116-
llvm::StringMap<pybind11::object> dialectClassMap;
115+
llvm::StringMap<nanobind::object> dialectClassMap;
117116
/// Map of full operation name to external operation class object.
118-
llvm::StringMap<pybind11::object> operationClassMap;
117+
llvm::StringMap<nanobind::object> operationClassMap;
119118
/// Map of attribute ODS name to custom builder.
120-
llvm::StringMap<pybind11::object> attributeBuilderMap;
119+
llvm::StringMap<nanobind::callable> attributeBuilderMap;
121120
/// Map of MlirTypeID to custom type caster.
122-
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
121+
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
123122
/// Map of MlirTypeID to custom value caster.
124-
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
123+
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
125124
/// Set of dialect namespaces that we have attempted to import implementation
126125
/// modules for.
127126
llvm::StringSet<> loadedDialectModules;

0 commit comments

Comments
 (0)