Skip to content

[mlir][python] extend LLVM bindings #89797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2024

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Apr 23, 2024

Add bindings for LLVM pointer type.

@llvmbot
Copy link
Member

llvmbot commented Apr 23, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/89797.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/LLVM.h (+7)
  • (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+35-8)
  • (modified) mlir/lib/CAPI/Dialect/LLVM.cpp (+8)
  • (modified) mlir/python/mlir/dialects/LLVMOps.td (+1)
  • (modified) mlir/python/mlir/dialects/llvm.py (+8)
  • (modified) mlir/test/python/dialects/llvm.py (+43)
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index bd9b7dd26f5e9e..b3e64bd68f7b1c 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
 MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
                                                    unsigned addressSpace);
 
+/// Returns `true` if the type is an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
+
+/// Returns address space of llvm.ptr
+MLIR_CAPI_EXPORTED unsigned
+mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
+
 /// Creates an llmv.void type.
 MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
 
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 843707751dd849..42a4c8c0793ba8 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -19,6 +19,11 @@ using namespace mlir::python;
 using namespace mlir::python::adaptors;
 
 void populateDialectLLVMSubmodule(const pybind11::module &m) {
+
+  //===--------------------------------------------------------------------===//
+  // StructType
+  //===--------------------------------------------------------------------===//
+
   auto llvmStructType =
       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
 
@@ -35,8 +40,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         }
         return cls(type);
       },
-      py::arg("cls"), py::arg("elements"), py::kw_only(),
-      py::arg("packed") = false, py::arg("loc") = py::none());
+      "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+      "loc"_a = py::none());
 
   llvmStructType.def_classmethod(
       "get_identified",
@@ -44,8 +49,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         return cls(mlirLLVMStructTypeIdentifiedGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      py::arg("cls"), py::arg("name"), py::kw_only(),
-      py::arg("context") = py::none());
+      "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
 
   llvmStructType.def_classmethod(
       "get_opaque",
@@ -53,7 +57,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         return cls(mlirLLVMStructTypeOpaqueGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+      "cls"_a, "name"_a, "context"_a = py::none());
 
   llvmStructType.def(
       "set_body",
@@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
               "Struct body already set to different content.");
         }
       },
-      py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+      "elements"_a, py::kw_only(), "packed"_a = false);
 
   llvmStructType.def_classmethod(
       "new_identified",
@@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
             ctx, mlirStringRefCreate(name.data(), name.length()),
             elements.size(), elements.data(), packed));
       },
-      py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
-      py::arg("packed") = false, py::arg("context") = py::none());
+      "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+      "context"_a = py::none());
 
   llvmStructType.def_property_readonly(
       "name", [](MlirType type) -> std::optional<std::string> {
@@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
 
   llvmStructType.def_property_readonly(
       "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+
+  //===--------------------------------------------------------------------===//
+  // PointerType
+  //===--------------------------------------------------------------------===//
+
+  mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
+      .def_classmethod(
+          "get",
+          [](py::object cls, std::optional<unsigned> addressSpace,
+             MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+            MlirType type = mlirLLVMPointerTypeGet(
+                context, addressSpace.has_value() ? *addressSpace : 0);
+            if (mlirTypeIsNull(type)) {
+              throw py::value_error(scope.takeMessage());
+            }
+            return cls(type);
+          },
+          "cls"_a, "address_space"_a = py::none(), py::kw_only(),
+          "context"_a = py::none())
+      .def_property_readonly("address_space", [](MlirType type) {
+        return mlirLLVMPointerTypeGetAddressSpace(type);
+      });
 }
 
 PYBIND11_MODULE(_mlirDialectsLLVM, m) {
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 4669c40f843d94..cd817539bb83a0 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
   return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
 }
 
+bool mlirTypeIsALLVMPointerType(MlirType type) {
+  return isa<LLVM::LLVMPointerType>(unwrap(type));
+}
+
+unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
+  return cast<LLVM::LLVMPointerType>(unwrap(pointerType)).getAddressSpace();
+}
+
 MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
   return wrap(LLVMVoidType::get(unwrap(ctx)));
 }
diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td
index dcf2f4245cf49f..30f047f21698e3 100644
--- a/mlir/python/mlir/dialects/LLVMOps.td
+++ b/mlir/python/mlir/dialects/LLVMOps.td
@@ -10,5 +10,6 @@
 #define PYTHON_BINDINGS_LLVM_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOps.td"
+include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"
 
 #endif
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 8aa16e4a256030..941a584966dcde 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,3 +5,11 @@
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
 from .._mlir_libs._mlirDialectsLLVM import *
+from ..ir import Value
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
+
+
+def mlir_constant(value, *, loc=None, ip=None) -> Value:
+    return _get_op_result_or_op_results(
+        ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
+    )
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index fb4b343b170bae..d9ffdeb65bfd40 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -107,3 +107,46 @@ def testSmoke():
     )
     result = llvm.UndefOp(mat64f32_t)
     # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+
+
+# CHECK-LABEL: testPointerType
+@constructAndPrintInModule
+def testPointerType():
+    ptr = llvm.PointerType.get()
+    # CHECK: !llvm.ptr
+    print(ptr)
+
+    ptr_with_addr = llvm.PointerType.get(1)
+    # CHECK: !llvm.ptr<1>
+    print(ptr_with_addr)
+
+
+# CHECK-LABEL: testConstant
+@constructAndPrintInModule
+def testConstant():
+    i32 = IntegerType.get_signless(32)
+    c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+    # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
+    print(c_128.owner)
+
+
+# CHECK-LABEL: testIntrinsics
+@constructAndPrintInModule
+def testIntrinsics():
+    i32 = IntegerType.get_signless(32)
+    ptr = llvm.PointerType.get()
+    c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+    # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
+    print(c_128.owner)
+
+    alloca = llvm.alloca(ptr, c_128, i32)
+    # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
+    print(alloca.owner)
+
+    c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
+    # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
+    print(c_0.owner)
+
+    result = llvm.intr_memset(alloca, c_0, c_128, False)
+    # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+    print(result)

@makslevental makslevental merged commit 79d4d16 into llvm:main Apr 24, 2024
@makslevental makslevental deleted the add_llvm_intrinsics branch April 24, 2024 12:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants