-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] extend LLVM bindings #89797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesFull diff: https://github.com/llvm/llvm-project/pull/89797.diff 6 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index bd9b7dd26f5e9e..b3e64bd68f7b1c 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
unsigned addressSpace);
+/// Returns `true` if the type is an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
+
+/// Returns address space of llvm.ptr
+MLIR_CAPI_EXPORTED unsigned
+mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
+
/// Creates an llmv.void type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 843707751dd849..42a4c8c0793ba8 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -19,6 +19,11 @@ using namespace mlir::python;
using namespace mlir::python::adaptors;
void populateDialectLLVMSubmodule(const pybind11::module &m) {
+
+ //===--------------------------------------------------------------------===//
+ // StructType
+ //===--------------------------------------------------------------------===//
+
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
@@ -35,8 +40,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
}
return cls(type);
},
- py::arg("cls"), py::arg("elements"), py::kw_only(),
- py::arg("packed") = false, py::arg("loc") = py::none());
+ "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+ "loc"_a = py::none());
llvmStructType.def_classmethod(
"get_identified",
@@ -44,8 +49,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
- py::arg("cls"), py::arg("name"), py::kw_only(),
- py::arg("context") = py::none());
+ "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
llvmStructType.def_classmethod(
"get_opaque",
@@ -53,7 +57,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
- py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+ "cls"_a, "name"_a, "context"_a = py::none());
llvmStructType.def(
"set_body",
@@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
"Struct body already set to different content.");
}
},
- py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+ "elements"_a, py::kw_only(), "packed"_a = false);
llvmStructType.def_classmethod(
"new_identified",
@@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
- py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
- py::arg("packed") = false, py::arg("context") = py::none());
+ "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+ "context"_a = py::none());
llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
@@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
llvmStructType.def_property_readonly(
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+
+ //===--------------------------------------------------------------------===//
+ // PointerType
+ //===--------------------------------------------------------------------===//
+
+ mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
+ .def_classmethod(
+ "get",
+ [](py::object cls, std::optional<unsigned> addressSpace,
+ MlirContext context) {
+ CollectDiagnosticsToStringScope scope(context);
+ MlirType type = mlirLLVMPointerTypeGet(
+ context, addressSpace.has_value() ? *addressSpace : 0);
+ if (mlirTypeIsNull(type)) {
+ throw py::value_error(scope.takeMessage());
+ }
+ return cls(type);
+ },
+ "cls"_a, "address_space"_a = py::none(), py::kw_only(),
+ "context"_a = py::none())
+ .def_property_readonly("address_space", [](MlirType type) {
+ return mlirLLVMPointerTypeGetAddressSpace(type);
+ });
}
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 4669c40f843d94..cd817539bb83a0 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
}
+bool mlirTypeIsALLVMPointerType(MlirType type) {
+ return isa<LLVM::LLVMPointerType>(unwrap(type));
+}
+
+unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
+ return cast<LLVM::LLVMPointerType>(unwrap(pointerType)).getAddressSpace();
+}
+
MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
return wrap(LLVMVoidType::get(unwrap(ctx)));
}
diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td
index dcf2f4245cf49f..30f047f21698e3 100644
--- a/mlir/python/mlir/dialects/LLVMOps.td
+++ b/mlir/python/mlir/dialects/LLVMOps.td
@@ -10,5 +10,6 @@
#define PYTHON_BINDINGS_LLVM_OPS
include "mlir/Dialect/LLVMIR/LLVMOps.td"
+include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 8aa16e4a256030..941a584966dcde 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,3 +5,11 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
+from ..ir import Value
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
+
+
+def mlir_constant(value, *, loc=None, ip=None) -> Value:
+ return _get_op_result_or_op_results(
+ ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
+ )
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index fb4b343b170bae..d9ffdeb65bfd40 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -107,3 +107,46 @@ def testSmoke():
)
result = llvm.UndefOp(mat64f32_t)
# CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+
+
+# CHECK-LABEL: testPointerType
+@constructAndPrintInModule
+def testPointerType():
+ ptr = llvm.PointerType.get()
+ # CHECK: !llvm.ptr
+ print(ptr)
+
+ ptr_with_addr = llvm.PointerType.get(1)
+ # CHECK: !llvm.ptr<1>
+ print(ptr_with_addr)
+
+
+# CHECK-LABEL: testConstant
+@constructAndPrintInModule
+def testConstant():
+ i32 = IntegerType.get_signless(32)
+ c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+ # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
+ print(c_128.owner)
+
+
+# CHECK-LABEL: testIntrinsics
+@constructAndPrintInModule
+def testIntrinsics():
+ i32 = IntegerType.get_signless(32)
+ ptr = llvm.PointerType.get()
+ c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+ # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
+ print(c_128.owner)
+
+ alloca = llvm.alloca(ptr, c_128, i32)
+ # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
+ print(alloca.owner)
+
+ c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
+ # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
+ print(c_0.owner)
+
+ result = llvm.intr_memset(alloca, c_0, c_128, False)
+ # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ print(result)
|
ftynse
approved these changes
Apr 23, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Add bindings for LLVM pointer type.