Skip to content

Commit 79d4d16

Browse files
authored
[mlir][python] extend LLVM bindings (#89797)
Add bindings for LLVM pointer type.
1 parent 6e9ea6e commit 79d4d16

File tree

6 files changed

+102
-8
lines changed

6 files changed

+102
-8
lines changed

mlir/include/mlir-c/Dialect/LLVM.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
2323
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
2424
unsigned addressSpace);
2525

26+
/// Returns `true` if the type is an LLVM dialect pointer type.
27+
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
28+
29+
/// Returns address space of llvm.ptr
30+
MLIR_CAPI_EXPORTED unsigned
31+
mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
32+
2633
/// Creates an llmv.void type.
2734
MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
2835

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ using namespace mlir::python;
1919
using namespace mlir::python::adaptors;
2020

2121
void populateDialectLLVMSubmodule(const pybind11::module &m) {
22+
23+
//===--------------------------------------------------------------------===//
24+
// StructType
25+
//===--------------------------------------------------------------------===//
26+
2227
auto llvmStructType =
2328
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
2429

@@ -35,25 +40,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
3540
}
3641
return cls(type);
3742
},
38-
py::arg("cls"), py::arg("elements"), py::kw_only(),
39-
py::arg("packed") = false, py::arg("loc") = py::none());
43+
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
44+
"loc"_a = py::none());
4045

4146
llvmStructType.def_classmethod(
4247
"get_identified",
4348
[](py::object cls, const std::string &name, MlirContext context) {
4449
return cls(mlirLLVMStructTypeIdentifiedGet(
4550
context, mlirStringRefCreate(name.data(), name.size())));
4651
},
47-
py::arg("cls"), py::arg("name"), py::kw_only(),
48-
py::arg("context") = py::none());
52+
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
4953

5054
llvmStructType.def_classmethod(
5155
"get_opaque",
5256
[](py::object cls, const std::string &name, MlirContext context) {
5357
return cls(mlirLLVMStructTypeOpaqueGet(
5458
context, mlirStringRefCreate(name.data(), name.size())));
5559
},
56-
py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
60+
"cls"_a, "name"_a, "context"_a = py::none());
5761

5862
llvmStructType.def(
5963
"set_body",
@@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
6569
"Struct body already set to different content.");
6670
}
6771
},
68-
py::arg("elements"), py::kw_only(), py::arg("packed") = false);
72+
"elements"_a, py::kw_only(), "packed"_a = false);
6973

7074
llvmStructType.def_classmethod(
7175
"new_identified",
@@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
7579
ctx, mlirStringRefCreate(name.data(), name.length()),
7680
elements.size(), elements.data(), packed));
7781
},
78-
py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
79-
py::arg("packed") = false, py::arg("context") = py::none());
82+
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
83+
"context"_a = py::none());
8084

8185
llvmStructType.def_property_readonly(
8286
"name", [](MlirType type) -> std::optional<std::string> {
@@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
105109

106110
llvmStructType.def_property_readonly(
107111
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
112+
113+
//===--------------------------------------------------------------------===//
114+
// PointerType
115+
//===--------------------------------------------------------------------===//
116+
117+
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
118+
.def_classmethod(
119+
"get",
120+
[](py::object cls, std::optional<unsigned> addressSpace,
121+
MlirContext context) {
122+
CollectDiagnosticsToStringScope scope(context);
123+
MlirType type = mlirLLVMPointerTypeGet(
124+
context, addressSpace.has_value() ? *addressSpace : 0);
125+
if (mlirTypeIsNull(type)) {
126+
throw py::value_error(scope.takeMessage());
127+
}
128+
return cls(type);
129+
},
130+
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
131+
"context"_a = py::none())
132+
.def_property_readonly("address_space", [](MlirType type) {
133+
return mlirLLVMPointerTypeGetAddressSpace(type);
134+
});
108135
}
109136

110137
PYBIND11_MODULE(_mlirDialectsLLVM, m) {

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
2727
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
2828
}
2929

30+
bool mlirTypeIsALLVMPointerType(MlirType type) {
31+
return isa<LLVM::LLVMPointerType>(unwrap(type));
32+
}
33+
34+
unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
35+
return cast<LLVM::LLVMPointerType>(unwrap(pointerType)).getAddressSpace();
36+
}
37+
3038
MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
3139
return wrap(LLVMVoidType::get(unwrap(ctx)));
3240
}

mlir/python/mlir/dialects/LLVMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
#define PYTHON_BINDINGS_LLVM_OPS
1111

1212
include "mlir/Dialect/LLVMIR/LLVMOps.td"
13+
include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"
1314

1415
#endif

mlir/python/mlir/dialects/llvm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,11 @@
55
from ._llvm_ops_gen import *
66
from ._llvm_enum_gen import *
77
from .._mlir_libs._mlirDialectsLLVM import *
8+
from ..ir import Value
9+
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
10+
11+
12+
def mlir_constant(value, *, loc=None, ip=None) -> Value:
13+
return _get_op_result_or_op_results(
14+
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
15+
)

mlir/test/python/dialects/llvm.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,46 @@ def testSmoke():
107107
)
108108
result = llvm.UndefOp(mat64f32_t)
109109
# CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
110+
111+
112+
# CHECK-LABEL: testPointerType
113+
@constructAndPrintInModule
114+
def testPointerType():
115+
ptr = llvm.PointerType.get()
116+
# CHECK: !llvm.ptr
117+
print(ptr)
118+
119+
ptr_with_addr = llvm.PointerType.get(1)
120+
# CHECK: !llvm.ptr<1>
121+
print(ptr_with_addr)
122+
123+
124+
# CHECK-LABEL: testConstant
125+
@constructAndPrintInModule
126+
def testConstant():
127+
i32 = IntegerType.get_signless(32)
128+
c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
129+
# CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
130+
print(c_128.owner)
131+
132+
133+
# CHECK-LABEL: testIntrinsics
134+
@constructAndPrintInModule
135+
def testIntrinsics():
136+
i32 = IntegerType.get_signless(32)
137+
ptr = llvm.PointerType.get()
138+
c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
139+
# CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
140+
print(c_128.owner)
141+
142+
alloca = llvm.alloca(ptr, c_128, i32)
143+
# CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
144+
print(alloca.owner)
145+
146+
c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
147+
# CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
148+
print(c_0.owner)
149+
150+
result = llvm.intr_memset(alloca, c_0, c_128, False)
151+
# CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
152+
print(result)

0 commit comments

Comments
 (0)