Skip to content

Commit bd8fcf7

Browse files
authored
[mlir][python] expose LLVMStructType API (llvm#81672)
Expose the API for constructing and inspecting StructTypes from the LLVM dialect. Separate constructor methods are used instead of overloads for better readability, similarly to IntegerType.
1 parent 6c84709 commit bd8fcf7

File tree

7 files changed

+525
-3
lines changed

7 files changed

+525
-3
lines changed

mlir/include/mlir-c/Dialect/LLVM.h

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,70 @@ MLIR_CAPI_EXPORTED MlirType
3434
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
3535
MlirType const *argumentTypes, bool isVarArg);
3636

37-
/// Creates an LLVM literal (unnamed) struct type.
37+
/// Returns `true` if the type is an LLVM dialect struct type.
38+
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
39+
40+
/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
41+
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
42+
43+
/// Returns the number of fields in the struct. Asserts if the struct is opaque
44+
/// or not yet initialized.
45+
MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);
46+
47+
/// Returns the `positions`-th field of the struct. Asserts if the struct is
48+
/// opaque, not yet initialized or if the position is out of range.
49+
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
50+
intptr_t position);
51+
52+
/// Returns `true` if the struct is packed.
53+
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);
54+
55+
/// Returns the identifier of the identified struct. Asserts that the struct is
56+
/// identified, i.e., not literal.
57+
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);
58+
59+
/// Returns `true` is the struct is explicitly opaque (will not have a body) or
60+
/// uninitialized (will eventually have a body).
61+
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);
62+
63+
/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
64+
/// have types not compatible with the LLVM dialect. For a graceful failure, use
65+
/// the checked version.
3866
MLIR_CAPI_EXPORTED MlirType
3967
mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
4068
MlirType const *fieldTypes, bool isPacked);
4169

70+
/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
71+
/// diagnostic at the given location and returns null otherwise.
72+
MLIR_CAPI_EXPORTED MlirType
73+
mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
74+
MlirType const *fieldTypes, bool isPacked);
75+
76+
/// Creates an LLVM identified struct type with no body. If a struct type with
77+
/// this name already exists in the context, returns that type. Use
78+
/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
79+
/// potentially renaming it. The body should be set separatelty by calling
80+
/// mlirLLVMStructTypeSetBody, if it isn't set already.
81+
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx,
82+
MlirStringRef name);
83+
84+
/// Creates an LLVM identified struct type with no body and a name starting with
85+
/// the given prefix. If a struct with the exact name as the given prefix
86+
/// already exists, appends an unspecified suffix to the name so that the name
87+
/// is unique in context.
88+
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
89+
MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
90+
MlirType const *fieldTypes, bool isPacked);
91+
92+
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
93+
MlirStringRef name);
94+
95+
/// Sets the body of the identified struct if it hasn't been set yet. Returns
96+
/// whether the operation was successful.
97+
MLIR_CAPI_EXPORTED MlirLogicalResult
98+
mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
99+
MlirType const *fieldTypes, bool isPacked);
100+
42101
#ifdef __cplusplus
43102
}
44103
#endif
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Diagnostics.h"
10+
#include "mlir-c/Dialect/LLVM.h"
11+
#include "mlir-c/IR.h"
12+
#include "mlir-c/Support.h"
13+
#include "mlir/Bindings/Python/PybindAdaptors.h"
14+
#include <string>
15+
16+
namespace py = pybind11;
17+
using namespace llvm;
18+
using namespace mlir;
19+
using namespace mlir::python;
20+
using namespace mlir::python::adaptors;
21+
22+
/// RAII scope intercepting all diagnostics into a string. The message must be
23+
/// checked before this goes out of scope.
24+
class CollectDiagnosticsToStringScope {
25+
public:
26+
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27+
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28+
/*deleteUserData=*/nullptr);
29+
}
30+
~CollectDiagnosticsToStringScope() {
31+
assert(errorMessage.empty() && "unchecked error message");
32+
mlirContextDetachDiagnosticHandler(context, handlerID);
33+
}
34+
35+
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
36+
37+
private:
38+
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
39+
auto printer = +[](MlirStringRef message, void *data) {
40+
*static_cast<std::string *>(data) +=
41+
StringRef(message.data, message.length);
42+
};
43+
mlirDiagnosticPrint(diag, printer, data);
44+
return mlirLogicalResultSuccess();
45+
}
46+
47+
MlirContext context;
48+
MlirDiagnosticHandlerID handlerID;
49+
std::string errorMessage = "";
50+
};
51+
52+
void populateDialectLLVMSubmodule(const pybind11::module &m) {
53+
auto llvmStructType =
54+
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
55+
56+
llvmStructType.def_classmethod(
57+
"get_literal",
58+
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
59+
MlirLocation loc) {
60+
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
61+
62+
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
63+
loc, elements.size(), elements.data(), packed);
64+
if (mlirTypeIsNull(type)) {
65+
throw py::value_error(scope.takeMessage());
66+
}
67+
return cls(type);
68+
},
69+
py::arg("cls"), py::arg("elements"), py::kw_only(),
70+
py::arg("packed") = false, py::arg("loc") = py::none());
71+
72+
llvmStructType.def_classmethod(
73+
"get_identified",
74+
[](py::object cls, const std::string &name, MlirContext context) {
75+
return cls(mlirLLVMStructTypeIdentifiedGet(
76+
context, mlirStringRefCreate(name.data(), name.size())));
77+
},
78+
py::arg("cls"), py::arg("name"), py::kw_only(),
79+
py::arg("context") = py::none());
80+
81+
llvmStructType.def_classmethod(
82+
"get_opaque",
83+
[](py::object cls, const std::string &name, MlirContext context) {
84+
return cls(mlirLLVMStructTypeOpaqueGet(
85+
context, mlirStringRefCreate(name.data(), name.size())));
86+
},
87+
py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
88+
89+
llvmStructType.def(
90+
"set_body",
91+
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
92+
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
93+
self, elements.size(), elements.data(), packed);
94+
if (!mlirLogicalResultIsSuccess(result)) {
95+
throw py::value_error(
96+
"Struct body already set to different content.");
97+
}
98+
},
99+
py::arg("elements"), py::kw_only(), py::arg("packed") = false);
100+
101+
llvmStructType.def_classmethod(
102+
"new_identified",
103+
[](py::object cls, const std::string &name,
104+
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
105+
return cls(mlirLLVMStructTypeIdentifiedNewGet(
106+
ctx, mlirStringRefCreate(name.data(), name.length()),
107+
elements.size(), elements.data(), packed));
108+
},
109+
py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
110+
py::arg("packed") = false, py::arg("context") = py::none());
111+
112+
llvmStructType.def_property_readonly(
113+
"name", [](MlirType type) -> std::optional<std::string> {
114+
if (mlirLLVMStructTypeIsLiteral(type))
115+
return std::nullopt;
116+
117+
MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
118+
return StringRef(stringRef.data, stringRef.length).str();
119+
});
120+
121+
llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
122+
// Don't crash in absence of a body.
123+
if (mlirLLVMStructTypeIsOpaque(type))
124+
return py::none();
125+
126+
py::list body;
127+
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
128+
++i) {
129+
body.append(mlirLLVMStructTypeGetElementType(type, i));
130+
}
131+
return body;
132+
});
133+
134+
llvmStructType.def_property_readonly(
135+
"packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
136+
137+
llvmStructType.def_property_readonly(
138+
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
139+
}
140+
141+
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
142+
m.doc() = "MLIR LLVM Dialect";
143+
144+
populateDialectLLVMSubmodule(m);
145+
}

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,77 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
3636
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
3737
}
3838

39+
bool mlirTypeIsALLVMStructType(MlirType type) {
40+
return isa<LLVM::LLVMStructType>(unwrap(type));
41+
}
42+
43+
bool mlirLLVMStructTypeIsLiteral(MlirType type) {
44+
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
45+
}
46+
47+
intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
48+
return cast<LLVM::LLVMStructType>(unwrap(type)).getBody().size();
49+
}
50+
51+
MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
52+
return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getBody()[position]);
53+
}
54+
55+
bool mlirLLVMStructTypeIsPacked(MlirType type) {
56+
return cast<LLVM::LLVMStructType>(unwrap(type)).isPacked();
57+
}
58+
59+
MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
60+
return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getName());
61+
}
62+
63+
bool mlirLLVMStructTypeIsOpaque(MlirType type) {
64+
return cast<LLVM::LLVMStructType>(unwrap(type)).isOpaque();
65+
}
66+
3967
MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
4068
MlirType const *fieldTypes,
4169
bool isPacked) {
42-
SmallVector<Type, 2> fieldStorage;
70+
SmallVector<Type> fieldStorage;
4371
return wrap(LLVMStructType::getLiteral(
4472
unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
4573
isPacked));
4674
}
75+
76+
MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
77+
intptr_t nFieldTypes,
78+
MlirType const *fieldTypes,
79+
bool isPacked) {
80+
SmallVector<Type> fieldStorage;
81+
return wrap(LLVMStructType::getLiteralChecked(
82+
[loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(),
83+
unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked));
84+
}
85+
86+
MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
87+
return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx)));
88+
}
89+
90+
MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
91+
return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name)));
92+
}
93+
94+
MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
95+
intptr_t nFieldTypes,
96+
MlirType const *fieldTypes,
97+
bool isPacked) {
98+
SmallVector<Type> fields;
99+
return wrap(LLVMStructType::getNewIdentified(
100+
unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields),
101+
isPacked));
102+
}
103+
104+
MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
105+
intptr_t nFieldTypes,
106+
MlirType const *fieldTypes,
107+
bool isPacked) {
108+
SmallVector<Type> fields;
109+
return wrap(
110+
cast<LLVM::LLVMStructType>(unwrap(structType))
111+
.setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked));
112+
}

mlir/python/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
482482
MLIRCAPILinalg
483483
)
484484

485+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
486+
MODULE_NAME _mlirDialectsLLVM
487+
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
488+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
489+
SOURCES
490+
DialectLLVM.cpp
491+
PRIVATE_LINK_LIBS
492+
LLVMSupport
493+
EMBED_CAPI_LINK_LIBS
494+
MLIRCAPIIR
495+
MLIRCAPILLVM
496+
)
497+
485498
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
486499
MODULE_NAME _mlirDialectsQuant
487500
ADD_TO_PARENT MLIRPythonSources.Dialects.quant

mlir/python/mlir/dialects/llvm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from ._llvm_ops_gen import *
66
from ._llvm_enum_gen import *
7+
from .._mlir_libs._mlirDialectsLLVM import *

0 commit comments

Comments
 (0)