Skip to content

Commit 4c3f1be

Browse files
[mlir][python] Add python binding for AffineMapAttribute.
Differential Revision: https://reviews.llvm.org/D96815
1 parent 60d71a2 commit 4c3f1be

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

mlir/lib/Bindings/Python/IRModules.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,6 +1763,23 @@ class PyConcreteAttribute : public BaseTy {
17631763
static void bindDerived(ClassTy &m) {}
17641764
};
17651765

1766+
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
1767+
public:
1768+
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
1769+
static constexpr const char *pyClassName = "AffineMapAttr";
1770+
using PyConcreteAttribute::PyConcreteAttribute;
1771+
1772+
static void bindDerived(ClassTy &c) {
1773+
c.def_static(
1774+
"get",
1775+
[](PyAffineMap &affineMap) {
1776+
MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
1777+
return PyAffineMapAttribute(affineMap.getContext(), attr);
1778+
},
1779+
py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
1780+
}
1781+
};
1782+
17661783
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
17671784
public:
17681785
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -3994,17 +4011,18 @@ void mlir::python::populateIRSubmodule(py::module &m) {
39944011
"The underlying generic attribute of the NamedAttribute binding");
39954012

39964013
// Builtin attribute bindings.
3997-
PyFloatAttribute::bind(m);
4014+
PyAffineMapAttribute::bind(m);
39984015
PyArrayAttribute::bind(m);
39994016
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
4000-
PyIntegerAttribute::bind(m);
40014017
PyBoolAttribute::bind(m);
4002-
PyFlatSymbolRefAttribute::bind(m);
4003-
PyStringAttribute::bind(m);
40044018
PyDenseElementsAttribute::bind(m);
4005-
PyDenseIntElementsAttribute::bind(m);
40064019
PyDenseFPElementsAttribute::bind(m);
4020+
PyDenseIntElementsAttribute::bind(m);
40074021
PyDictAttribute::bind(m);
4022+
PyFlatSymbolRefAttribute::bind(m);
4023+
PyFloatAttribute::bind(m);
4024+
PyIntegerAttribute::bind(m);
4025+
PyStringAttribute::bind(m);
40084026
PyTypeAttribute::bind(m);
40094027
PyUnitAttribute::bind(m);
40104028

mlir/test/Bindings/Python/ir_attributes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,24 @@ def testStandardAttrCasts():
107107
run(testStandardAttrCasts)
108108

109109

110+
# CHECK-LABEL: TEST: testAffineMapAttr
111+
def testAffineMapAttr():
112+
with Context() as ctx:
113+
d0 = AffineDimExpr.get(0)
114+
d1 = AffineDimExpr.get(1)
115+
c2 = AffineConstantExpr.get(2)
116+
map0 = AffineMap.get(2, 3, [])
117+
118+
# CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
119+
attr_built = AffineMapAttr.get(map0)
120+
print(str(attr_built))
121+
122+
attr_parsed = Attribute.parse(str(attr_built))
123+
assert attr_built == attr_parsed
124+
125+
run(testAffineMapAttr)
126+
127+
110128
# CHECK-LABEL: TEST: testFloatAttr
111129
def testFloatAttr():
112130
with Context(), Location.unknown():

0 commit comments

Comments
 (0)