Skip to content

Commit 334873f

Browse files
authored
[MLIR][Python] Python binding support for IntegerSet attribute (#107640)
Support IntegerSet attribute python binding.
1 parent e50131a commit 334873f

File tree

6 files changed

+78
-1
lines changed

6 files changed

+78
-1
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir-c/AffineMap.h"
1818
#include "mlir-c/IR.h"
19+
#include "mlir-c/IntegerSet.h"
1920
#include "mlir-c/Support.h"
2021

2122
#ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
177178
/// Checks whether the given attribute is an integer set attribute.
178179
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
179180

181+
/// Creates an integer set attribute wrapping the given set. The attribute
182+
/// belongs to the same context as the integer set.
183+
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
184+
185+
/// Returns the integer set wrapped in the given integer set attribute.
186+
MLIR_CAPI_EXPORTED MlirIntegerSet
187+
mlirIntegerSetAttrGetValue(MlirAttribute attr);
188+
180189
/// Returns the typeID of an IntegerSet attribute.
181190
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
182191

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
147147
}
148148
};
149149

150+
class PyIntegerSetAttribute
151+
: public PyConcreteAttribute<PyIntegerSetAttribute> {
152+
public:
153+
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
154+
static constexpr const char *pyClassName = "IntegerSetAttr";
155+
using PyConcreteAttribute::PyConcreteAttribute;
156+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
157+
mlirIntegerSetAttrGetTypeID;
158+
159+
static void bindDerived(ClassTy &c) {
160+
c.def_static(
161+
"get",
162+
[](PyIntegerSet &integerSet) {
163+
MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
164+
return PyIntegerSetAttribute(integerSet.getContext(), attr);
165+
},
166+
py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
167+
}
168+
};
169+
150170
template <typename T>
151171
static T pyTryCast(py::handle object) {
152172
try {
@@ -1426,7 +1446,6 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
14261446

14271447
void mlir::python::populateIRAttributes(py::module &m) {
14281448
PyAffineMapAttribute::bind(m);
1429-
14301449
PyDenseBoolArrayAttribute::bind(m);
14311450
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
14321451
PyDenseI8ArrayAttribute::bind(m);
@@ -1466,6 +1485,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
14661485
PyOpaqueAttribute::bind(m);
14671486
PyFloatAttribute::bind(m);
14681487
PyIntegerAttribute::bind(m);
1488+
PyIntegerSetAttribute::bind(m);
14691489
PyStringAttribute::bind(m);
14701490
PyTypeAttribute::bind(m);
14711491
PyGlobals::get().registerTypeCaster(

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir-c/Support.h"
1111
#include "mlir/CAPI/AffineMap.h"
1212
#include "mlir/CAPI/IR.h"
13+
#include "mlir/CAPI/IntegerSet.h"
1314
#include "mlir/CAPI/Support.h"
1415
#include "mlir/IR/AsmState.h"
1516
#include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
192193
return wrap(IntegerSetAttr::getTypeID());
193194
}
194195

196+
MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
197+
return wrap(IntegerSetAttr::get(unwrap(set)));
198+
}
199+
200+
MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
201+
return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
202+
}
203+
195204
//===----------------------------------------------------------------------===//
196205
// Opaque attribute.
197206
//===----------------------------------------------------------------------===//

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ __all__ = [
138138
"InsertionPoint",
139139
"IntegerAttr",
140140
"IntegerSet",
141+
"IntegerSetAttr",
141142
"IntegerSetConstraint",
142143
"IntegerSetConstraintList",
143144
"IntegerType",
@@ -1905,6 +1906,21 @@ class IntegerSet:
19051906
@property
19061907
def n_symbols(self) -> int: ...
19071908

1909+
class IntegerSetAttr(Attribute):
1910+
static_typeid: ClassVar[TypeID]
1911+
@staticmethod
1912+
def get(integer_set) -> IntegerSetAttr:
1913+
"""
1914+
Gets an attribute wrapping an IntegerSet.
1915+
"""
1916+
@staticmethod
1917+
def isinstance(other: Attribute) -> bool: ...
1918+
def __init__(self, cast_from_attr: Attribute) -> None: ...
1919+
@property
1920+
def type(self) -> Type: ...
1921+
@property
1922+
def typeid(self) -> TypeID: ...
1923+
19081924
class IntegerSetConstraint:
19091925
def __init__(self, *args, **kwargs) -> None: ...
19101926
@property

mlir/python/mlir/ir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
2222
return AffineMapAttr.get(x)
2323

2424

25+
@register_attribute_builder("IntegerSetAttr")
26+
def _integerSetAttr(x, context):
27+
return IntegerSetAttr.get(x)
28+
29+
2530
@register_attribute_builder("BoolAttr")
2631
def _boolAttr(x, context):
2732
return BoolAttr.get(x, context=context)

mlir/test/python/ir/attributes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,24 @@ def testAffineMapAttr():
162162
assert attr_built == attr_parsed
163163

164164

165+
# CHECK-LABEL: TEST: testIntegerSetAttr
166+
@run
167+
def testIntegerSetAttr():
168+
with Context() as ctx:
169+
d0 = AffineDimExpr.get(0)
170+
d1 = AffineDimExpr.get(1)
171+
s0 = AffineSymbolExpr.get(0)
172+
c42 = AffineConstantExpr.get(42)
173+
set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
174+
175+
# CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
176+
attr_built = IntegerSetAttr.get(set0)
177+
print(str(attr_built))
178+
179+
attr_parsed = Attribute.parse(str(attr_built))
180+
assert attr_built == attr_parsed
181+
182+
165183
# CHECK-LABEL: TEST: testFloatAttr
166184
@run
167185
def testFloatAttr():

0 commit comments

Comments
 (0)