Skip to content

Commit 69d1d41

Browse files
committed
[mlir python] Port in-tree dialects to nanobind.
This is a companion to #118583, although it can be landed independently because since #117922 dialects do not have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, with the exception of those that remain for testing pybind11 support. It would make sense to merge this PR after merging #118583, if we have agreed that we are migrating the core to nanobind. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR and it is duplicated in Diagnostics.h. * removes some extraneous semicolons in NanobindAdaptors.h
1 parent 3cbc73f commit 69d1d41

File tree

19 files changed

+298
-303
lines changed

19 files changed

+298
-303
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
6464
/// Casts object <-> MlirAffineMap.
6565
template <>
6666
struct type_caster<MlirAffineMap> {
67-
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
67+
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
6868
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
6969
nanobind::object capsule = mlirApiObjectToCapsule(src);
7070
value = mlirPythonCapsuleToAffineMap(capsule.ptr());
@@ -87,7 +87,7 @@ struct type_caster<MlirAffineMap> {
8787
/// Casts object <-> MlirAttribute.
8888
template <>
8989
struct type_caster<MlirAttribute> {
90-
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
90+
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
9191
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
9292
nanobind::object capsule = mlirApiObjectToCapsule(src);
9393
value = mlirPythonCapsuleToAttribute(capsule.ptr());
@@ -108,7 +108,7 @@ struct type_caster<MlirAttribute> {
108108
/// Casts object -> MlirBlock.
109109
template <>
110110
struct type_caster<MlirBlock> {
111-
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
111+
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
112112
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
113113
nanobind::object capsule = mlirApiObjectToCapsule(src);
114114
value = mlirPythonCapsuleToBlock(capsule.ptr());
@@ -119,7 +119,7 @@ struct type_caster<MlirBlock> {
119119
/// Casts object -> MlirContext.
120120
template <>
121121
struct type_caster<MlirContext> {
122-
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
122+
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
123123
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
124124
if (src.is_none()) {
125125
// Gets the current thread-bound context.
@@ -139,7 +139,7 @@ struct type_caster<MlirContext> {
139139
/// Casts object <-> MlirDialectRegistry.
140140
template <>
141141
struct type_caster<MlirDialectRegistry> {
142-
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
142+
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
143143
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
144144
nanobind::object capsule = mlirApiObjectToCapsule(src);
145145
value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
@@ -159,7 +159,7 @@ struct type_caster<MlirDialectRegistry> {
159159
/// Casts object <-> MlirLocation.
160160
template <>
161161
struct type_caster<MlirLocation> {
162-
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
162+
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
163163
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
164164
if (src.is_none()) {
165165
// Gets the current thread-bound context.
@@ -185,7 +185,7 @@ struct type_caster<MlirLocation> {
185185
/// Casts object <-> MlirModule.
186186
template <>
187187
struct type_caster<MlirModule> {
188-
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
188+
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
189189
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
190190
nanobind::object capsule = mlirApiObjectToCapsule(src);
191191
value = mlirPythonCapsuleToModule(capsule.ptr());
@@ -206,7 +206,7 @@ struct type_caster<MlirModule> {
206206
template <>
207207
struct type_caster<MlirFrozenRewritePatternSet> {
208208
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
209-
const_name("MlirFrozenRewritePatternSet"));
209+
const_name("MlirFrozenRewritePatternSet"))
210210
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
211211
nanobind::object capsule = mlirApiObjectToCapsule(src);
212212
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
@@ -225,7 +225,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
225225
/// Casts object <-> MlirOperation.
226226
template <>
227227
struct type_caster<MlirOperation> {
228-
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
228+
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
229229
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
230230
nanobind::object capsule = mlirApiObjectToCapsule(src);
231231
value = mlirPythonCapsuleToOperation(capsule.ptr());
@@ -247,7 +247,7 @@ struct type_caster<MlirOperation> {
247247
/// Casts object <-> MlirValue.
248248
template <>
249249
struct type_caster<MlirValue> {
250-
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
250+
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
251251
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
252252
nanobind::object capsule = mlirApiObjectToCapsule(src);
253253
value = mlirPythonCapsuleToValue(capsule.ptr());
@@ -270,7 +270,7 @@ struct type_caster<MlirValue> {
270270
/// Casts object -> MlirPassManager.
271271
template <>
272272
struct type_caster<MlirPassManager> {
273-
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
273+
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
274274
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
275275
nanobind::object capsule = mlirApiObjectToCapsule(src);
276276
value = mlirPythonCapsuleToPassManager(capsule.ptr());
@@ -281,7 +281,7 @@ struct type_caster<MlirPassManager> {
281281
/// Casts object <-> MlirTypeID.
282282
template <>
283283
struct type_caster<MlirTypeID> {
284-
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
284+
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
285285
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
286286
nanobind::object capsule = mlirApiObjectToCapsule(src);
287287
value = mlirPythonCapsuleToTypeID(capsule.ptr());
@@ -303,7 +303,7 @@ struct type_caster<MlirTypeID> {
303303
/// Casts object <-> MlirType.
304304
template <>
305305
struct type_caster<MlirType> {
306-
NB_TYPE_CASTER(MlirType, const_name("MlirType"));
306+
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
307307
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
308308
nanobind::object capsule = mlirApiObjectToCapsule(src);
309309
value = mlirPythonCapsuleToType(capsule.ptr());
@@ -631,40 +631,6 @@ class mlir_value_subclass : public pure_subclass {
631631

632632
} // namespace nanobind_adaptors
633633

634-
/// RAII scope intercepting all diagnostics into a string. The message must be
635-
/// checked before this goes out of scope.
636-
class CollectDiagnosticsToStringScope {
637-
public:
638-
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
639-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
640-
/*deleteUserData=*/nullptr);
641-
}
642-
~CollectDiagnosticsToStringScope() {
643-
assert(errorMessage.empty() && "unchecked error message");
644-
mlirContextDetachDiagnosticHandler(context, handlerID);
645-
}
646-
647-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
648-
649-
private:
650-
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
651-
auto printer = +[](MlirStringRef message, void *data) {
652-
*static_cast<std::string *>(data) +=
653-
llvm::StringRef(message.data, message.length);
654-
};
655-
MlirLocation loc = mlirDiagnosticGetLocation(diag);
656-
*static_cast<std::string *>(data) += "at ";
657-
mlirLocationPrint(loc, printer, data);
658-
*static_cast<std::string *>(data) += ": ";
659-
mlirDiagnosticPrint(diag, printer, data);
660-
return mlirLogicalResultSuccess();
661-
}
662-
663-
MlirContext context;
664-
MlirDiagnosticHandlerID handlerID;
665-
std::string errorMessage = "";
666-
};
667-
668634
} // namespace python
669635
} // namespace mlir
670636

mlir/lib/Bindings/Python/AsyncPasses.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88

99
#include "mlir-c/Dialect/Async.h"
1010

11-
#include <pybind11/detail/common.h>
12-
#include <pybind11/pybind11.h>
11+
#include <nanobind/nanobind.h>
1312

1413
// -----------------------------------------------------------------------------
1514
// Module initialization.
1615
// -----------------------------------------------------------------------------
1716

18-
PYBIND11_MODULE(_mlirAsyncPasses, m) {
17+
NB_MODULE(_mlirAsyncPasses, m) {
1918
m.doc() = "MLIR Async Dialect Passes";
2019

2120
// Register all Async passes on load.

mlir/lib/Bindings/Python/DialectGPU.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir-c/IR.h"
1111
#include "mlir-c/Support.h"
12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

14-
#include <pybind11/detail/common.h>
15-
#include <pybind11/pybind11.h>
14+
#include <nanobind/nanobind.h>
15+
#include <nanobind/stl/optional.h>
16+
17+
namespace nb = nanobind;
18+
using namespace nanobind::literals;
1619

17-
namespace py = pybind11;
1820
using namespace mlir;
1921
using namespace mlir::python;
20-
using namespace mlir::python::adaptors;
22+
using namespace mlir::python::nanobind_adaptors;
2123

2224
// -----------------------------------------------------------------------------
2325
// Module initialization.
2426
// -----------------------------------------------------------------------------
2527

26-
PYBIND11_MODULE(_mlirDialectsGPU, m) {
28+
NB_MODULE(_mlirDialectsGPU, m) {
2729
m.doc() = "MLIR GPU Dialect";
2830
//===-------------------------------------------------------------------===//
2931
// AsyncTokenType
@@ -34,11 +36,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
3436

3537
mlirGPUAsyncTokenType.def_classmethod(
3638
"get",
37-
[](py::object cls, MlirContext ctx) {
39+
[](nb::object cls, MlirContext ctx) {
3840
return cls(mlirGPUAsyncTokenTypeGet(ctx));
3941
},
40-
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41-
py::arg("ctx") = py::none());
42+
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
43+
nb::arg("ctx").none() = nb::none());
4244

4345
//===-------------------------------------------------------------------===//
4446
// ObjectAttr
@@ -47,12 +49,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
4749
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
4850
.def_classmethod(
4951
"get",
50-
[](py::object cls, MlirAttribute target, uint32_t format,
51-
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
52+
[](nb::object cls, MlirAttribute target, uint32_t format,
53+
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
5254
std::optional<MlirAttribute> mlirKernelsAttr) {
53-
py::buffer_info info(py::buffer(object).request());
54-
MlirStringRef objectStrRef =
55-
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
55+
MlirStringRef objectStrRef = mlirStringRefCreate(
56+
static_cast<char *>(const_cast<void *>(object.data())),
57+
object.size());
5658
return cls(mlirGPUObjectAttrGetWithKernels(
5759
mlirAttributeGetContext(target), target, format, objectStrRef,
5860
mlirObjectProps.has_value() ? *mlirObjectProps
@@ -61,7 +63,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
6163
: MlirAttribute{nullptr}));
6264
},
6365
"cls"_a, "target"_a, "format"_a, "object"_a,
64-
"properties"_a = py::none(), "kernels"_a = py::none(),
66+
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
6567
"Gets a gpu.object from parameters.")
6668
.def_property_readonly(
6769
"target",
@@ -73,18 +75,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
7375
"object",
7476
[](MlirAttribute self) {
7577
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76-
return py::bytes(stringRef.data, stringRef.length);
78+
return nb::bytes(stringRef.data, stringRef.length);
7779
})
7880
.def_property_readonly("properties",
79-
[](MlirAttribute self) {
81+
[](MlirAttribute self) -> nb::object {
8082
if (mlirGPUObjectAttrHasProperties(self))
81-
return py::cast(
83+
return nb::cast(
8284
mlirGPUObjectAttrGetProperties(self));
83-
return py::none().cast<py::object>();
85+
return nb::none();
8486
})
85-
.def_property_readonly("kernels", [](MlirAttribute self) {
87+
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
8688
if (mlirGPUObjectAttrHasKernels(self))
87-
return py::cast(mlirGPUObjectAttrGetKernels(self));
88-
return py::none().cast<py::object>();
89+
return nb::cast(mlirGPUObjectAttrGetKernels(self));
90+
return nb::none();
8991
});
9092
}

0 commit comments

Comments
 (0)