Skip to content

Commit c912f0e

Browse files
[mlir][python] Add bindings for mlirDenseElementsAttrGet (#91389)
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
1 parent 25c021a commit c912f0e

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "PybindUtils.h"
1616

1717
#include "llvm/ADT/ScopeExit.h"
18+
#include "llvm/Support/raw_ostream.h"
1819

1920
#include "mlir-c/BuiltinAttributes.h"
2021
#include "mlir-c/BuiltinTypes.h"
@@ -72,6 +73,27 @@ or 255), then a splat will be created.
7273
type or if the buffer does not meet expectations.
7374
)";
7475

76+
static const char kDenseElementsAttrGetFromListDocstring[] =
77+
R"(Gets a DenseElementsAttr from a Python list of attributes.
78+
79+
Note that it can be expensive to construct attributes individually.
80+
For a large number of elements, consider using a Python buffer or array instead.
81+
82+
Args:
83+
attrs: A list of attributes.
84+
type: The desired shape and type of the resulting DenseElementsAttr.
85+
If not provided, the element type is determined based on the type
86+
of the 0th attribute and the shape is `[len(attrs)]`.
87+
context: Explicit context, if not from context manager.
88+
89+
Returns:
90+
DenseElementsAttr on success.
91+
92+
Raises:
93+
ValueError: If the type of the attributes does not match the type
94+
specified by `shaped_type`.
95+
)";
96+
7597
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
7698
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
7799
@@ -647,6 +669,57 @@ class PyDenseElementsAttribute
647669
static constexpr const char *pyClassName = "DenseElementsAttr";
648670
using PyConcreteAttribute::PyConcreteAttribute;
649671

672+
static PyDenseElementsAttribute
673+
getFromList(py::list attributes, std::optional<PyType> explicitType,
674+
DefaultingPyMlirContext contextWrapper) {
675+
676+
const size_t numAttributes = py::len(attributes);
677+
if (numAttributes == 0)
678+
throw py::value_error("Attributes list must be non-empty.");
679+
680+
MlirType shapedType;
681+
if (explicitType) {
682+
if ((!mlirTypeIsAShaped(*explicitType) ||
683+
!mlirShapedTypeHasStaticShape(*explicitType))) {
684+
685+
std::string message;
686+
llvm::raw_string_ostream os(message);
687+
os << "Expected a static ShapedType for the shaped_type parameter: "
688+
<< py::repr(py::cast(*explicitType));
689+
throw py::value_error(os.str());
690+
}
691+
shapedType = *explicitType;
692+
} else {
693+
SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
694+
shapedType = mlirRankedTensorTypeGet(
695+
shape.size(), shape.data(),
696+
mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
697+
mlirAttributeGetNull());
698+
}
699+
700+
SmallVector<MlirAttribute> mlirAttributes;
701+
mlirAttributes.reserve(numAttributes);
702+
for (const py::handle &attribute : attributes) {
703+
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
704+
MlirType attrType = mlirAttributeGetType(mlirAttribute);
705+
mlirAttributes.push_back(mlirAttribute);
706+
707+
if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
708+
std::string message;
709+
llvm::raw_string_ostream os(message);
710+
os << "All attributes must be of the same type and match "
711+
<< "the type parameter: expected=" << py::repr(py::cast(shapedType))
712+
<< ", but got=" << py::repr(py::cast(attrType));
713+
throw py::value_error(os.str());
714+
}
715+
}
716+
717+
MlirAttribute elements = mlirDenseElementsAttrGet(
718+
shapedType, mlirAttributes.size(), mlirAttributes.data());
719+
720+
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
721+
}
722+
650723
static PyDenseElementsAttribute
651724
getFromBuffer(py::buffer array, bool signless,
652725
std::optional<PyType> explicitType,
@@ -883,6 +956,10 @@ class PyDenseElementsAttribute
883956
py::arg("type") = py::none(), py::arg("shape") = py::none(),
884957
py::arg("context") = py::none(),
885958
kDenseElementsAttrGetDocstring)
959+
.def_static("get", PyDenseElementsAttribute::getFromList,
960+
py::arg("attrs"), py::arg("type") = py::none(),
961+
py::arg("context") = py::none(),
962+
kDenseElementsAttrGetFromListDocstring)
886963
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
887964
py::arg("shaped_type"), py::arg("element_attr"),
888965
"Gets a DenseElementsAttr where all values are the same")

mlir/test/python/ir/array_attributes.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
5050
print(np.array(attr))
5151

5252

53+
################################################################################
54+
# Tests of the list of attributes .get() factory method
55+
################################################################################
56+
57+
58+
# CHECK-LABEL: TEST: testGetDenseElementsFromList
59+
@run
60+
def testGetDenseElementsFromList():
61+
with Context(), Location.unknown():
62+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
63+
attr = DenseElementsAttr.get(attrs)
64+
65+
# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
66+
print(attr)
67+
68+
69+
# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
70+
@run
71+
def testGetDenseElementsFromListWithExplicitType():
72+
with Context(), Location.unknown():
73+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
74+
shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
75+
attr = DenseElementsAttr.get(attrs, shaped_type)
76+
77+
# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
78+
print(attr)
79+
80+
81+
# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
82+
@run
83+
def testGetDenseElementsFromListEmptyList():
84+
with Context(), Location.unknown():
85+
attrs = []
86+
87+
try:
88+
attr = DenseElementsAttr.get(attrs)
89+
except ValueError as e:
90+
# CHECK: Attributes list must be non-empty
91+
print(e)
92+
93+
94+
# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
95+
@run
96+
def testGetDenseElementsFromListNonAttributeType():
97+
with Context(), Location.unknown():
98+
attrs = [1.0]
99+
100+
try:
101+
attr = DenseElementsAttr.get(attrs)
102+
except RuntimeError as e:
103+
# CHECK: Invalid attribute when attempting to create an ArrayAttribute
104+
print(e)
105+
106+
107+
# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
108+
@run
109+
def testGetDenseElementsFromListMismatchedType():
110+
with Context(), Location.unknown():
111+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
112+
shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
113+
114+
try:
115+
attr = DenseElementsAttr.get(attrs, shaped_type)
116+
except ValueError as e:
117+
# CHECK: All attributes must be of the same type and match the type parameter
118+
print(e)
119+
120+
121+
# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
122+
@run
123+
def testGetDenseElementsFromListMixedTypes():
124+
with Context(), Location.unknown():
125+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
126+
127+
try:
128+
attr = DenseElementsAttr.get(attrs)
129+
except ValueError as e:
130+
# CHECK: All attributes must be of the same type and match the type parameter
131+
print(e)
132+
133+
53134
################################################################################
54135
# Splats.
55136
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
205286

206287
### float and double arrays.
207288

289+
208290
# CHECK-LABEL: TEST: testGetDenseElementsF16
209291
@run
210292
def testGetDenseElementsF16():

0 commit comments

Comments
 (0)