Skip to content

Commit 63a8b1a

Browse files
[mlir][python] Add bindings for mlirDenseElementsAttrGet
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
1 parent dca3a6e commit 63a8b1a

File tree

2 files changed

+158
-5
lines changed

2 files changed

+158
-5
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ or 255), then a splat will be created.
7272
type or if the buffer does not meet expectations.
7373
)";
7474

75+
static const char kDenseElementsAttrGetFromListDocstring[] =
76+
R"(Gets a DenseElementsAttr from a Python list of attributes.
77+
78+
Args:
79+
attrs: A list of attributes.
80+
type: The desired shape and type of the resulting DenseElementsAttr.
81+
If not provided, the element type is determined based on the types
82+
of the attributes and the shape is `[len(attrs)]`.
83+
context: Explicit context, if not from context manager.
84+
85+
Returns:
86+
DenseElementsAttr on success.
87+
88+
Raises:
89+
ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
90+
)";
91+
7592
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
7693
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
7794
@@ -647,6 +664,55 @@ class PyDenseElementsAttribute
647664
static constexpr const char *pyClassName = "DenseElementsAttr";
648665
using PyConcreteAttribute::PyConcreteAttribute;
649666

667+
static PyDenseElementsAttribute
668+
getFromList(py::list attributes, std::optional<PyType> explicitType,
669+
DefaultingPyMlirContext contextWrapper) {
670+
671+
if (py::len(attributes) == 0) {
672+
throw py::value_error("Attributes list must be non-empty");
673+
}
674+
675+
MlirType shapedType;
676+
if (explicitType) {
677+
if ((!mlirTypeIsAShaped(*explicitType) ||
678+
!mlirShapedTypeHasStaticShape(*explicitType))) {
679+
std::string message =
680+
"Expected a static ShapedType for the shaped_type parameter: ";
681+
message.append(py::repr(py::cast(*explicitType)));
682+
throw py::value_error(message);
683+
}
684+
shapedType = *explicitType;
685+
} else {
686+
SmallVector<int64_t> shape{static_cast<int64_t>(py::len(attributes))};
687+
shapedType = mlirRankedTensorTypeGet(
688+
shape.size(), shape.data(),
689+
mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
690+
mlirAttributeGetNull());
691+
}
692+
693+
SmallVector<MlirAttribute> mlirAttributes;
694+
mlirAttributes.reserve(py::len(attributes));
695+
for (auto attribute : attributes) {
696+
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
697+
MlirType attrType = mlirAttributeGetType(mlirAttribute);
698+
mlirAttributes.push_back(mlirAttribute);
699+
700+
if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
701+
std::string message = "All attributes must be of the same type and "
702+
"match the type parameter: expected=";
703+
message.append(py::repr(py::cast(shapedType)));
704+
message.append(", but got=");
705+
message.append(py::repr(py::cast(attrType)));
706+
throw py::value_error(message);
707+
}
708+
}
709+
710+
MlirAttribute elements = mlirDenseElementsAttrGet(
711+
shapedType, mlirAttributes.size(), mlirAttributes.data());
712+
713+
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
714+
}
715+
650716
static PyDenseElementsAttribute
651717
getFromBuffer(py::buffer array, bool signless,
652718
std::optional<PyType> explicitType,
@@ -883,6 +949,10 @@ class PyDenseElementsAttribute
883949
py::arg("type") = py::none(), py::arg("shape") = py::none(),
884950
py::arg("context") = py::none(),
885951
kDenseElementsAttrGetDocstring)
952+
.def_static("get", PyDenseElementsAttribute::getFromList,
953+
py::arg("attrs"), py::arg("type") = py::none(),
954+
py::arg("context") = py::none(),
955+
kDenseElementsAttrGetFromListDocstring)
886956
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
887957
py::arg("shaped_type"), py::arg("element_attr"),
888958
"Gets a DenseElementsAttr where all values are the same")
@@ -954,8 +1024,8 @@ class PyDenseElementsAttribute
9541024
}
9551025
}; // namespace
9561026

957-
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
958-
/// (and boolean) values. Supports element access.
1027+
/// Refinement of the PyDenseElementsAttribute for attributes containing
1028+
/// integer (and boolean) values. Supports element access.
9591029
class PyDenseIntElementsAttribute
9601030
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
9611031
PyDenseElementsAttribute> {
@@ -964,8 +1034,8 @@ class PyDenseIntElementsAttribute
9641034
static constexpr const char *pyClassName = "DenseIntElementsAttr";
9651035
using PyConcreteAttribute::PyConcreteAttribute;
9661036

967-
/// Returns the element at the given linear position. Asserts if the index is
968-
/// out of range.
1037+
/// Returns the element at the given linear position. Asserts if the index
1038+
/// is out of range.
9691039
py::int_ dunderGetItem(intptr_t pos) {
9701040
if (pos < 0 || pos >= dunderLen()) {
9711041
throw py::index_error("attempt to access out of bounds element");
@@ -1267,7 +1337,8 @@ class PyStridedLayoutAttribute
12671337
return PyStridedLayoutAttribute(ctx->getRef(), attr);
12681338
},
12691339
py::arg("rank"), py::arg("context") = py::none(),
1270-
"Gets a strided layout attribute with dynamic offset and strides of a "
1340+
"Gets a strided layout attribute with dynamic offset and strides of "
1341+
"a "
12711342
"given rank.");
12721343
c.def_property_readonly(
12731344
"offset",

mlir/test/python/ir/array_attributes.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
5050
print(np.array(attr))
5151

5252

53+
################################################################################
54+
# Tests of the list of attributes .get() factory method
55+
################################################################################
56+
57+
58+
# CHECK-LABEL: TEST: testGetDenseElementsFromList
59+
@run
60+
def testGetDenseElementsFromList():
61+
with Context(), Location.unknown():
62+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
63+
attr = DenseElementsAttr.get(attrs)
64+
65+
# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
66+
print(attr)
67+
68+
69+
# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
70+
@run
71+
def testGetDenseElementsFromListWithExplicitType():
72+
with Context(), Location.unknown():
73+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
74+
shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
75+
attr = DenseElementsAttr.get(attrs, shaped_type)
76+
77+
# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
78+
print(attr)
79+
80+
81+
# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
82+
@run
83+
def testGetDenseElementsFromListEmptyList():
84+
with Context(), Location.unknown():
85+
attrs = []
86+
87+
try:
88+
attr = DenseElementsAttr.get(attrs)
89+
except ValueError as e:
90+
# CHECK: Attributes list must be non-empty
91+
print(e)
92+
93+
94+
# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
95+
@run
96+
def testGetDenseElementsFromListNonAttributeType():
97+
with Context(), Location.unknown():
98+
attrs = [1.0]
99+
100+
try:
101+
attr = DenseElementsAttr.get(attrs)
102+
except RuntimeError as e:
103+
# CHECK: Invalid attribute when attempting to create an ArrayAttribute
104+
print(e)
105+
106+
107+
# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
108+
@run
109+
def testGetDenseElementsFromListMismatchedType():
110+
with Context(), Location.unknown():
111+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
112+
shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
113+
114+
try:
115+
attr = DenseElementsAttr.get(attrs, shaped_type)
116+
except ValueError as e:
117+
# CHECK: All attributes must be of the same type and match the type parameter
118+
print(e)
119+
120+
121+
# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
122+
@run
123+
def testGetDenseElementsFromListMixedTypes():
124+
with Context(), Location.unknown():
125+
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
126+
127+
try:
128+
attr = DenseElementsAttr.get(attrs)
129+
except ValueError as e:
130+
# CHECK: All attributes must be of the same type and match the type parameter
131+
print(e)
132+
133+
53134
################################################################################
54135
# Splats.
55136
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
205286

206287
### float and double arrays.
207288

289+
208290
# CHECK-LABEL: TEST: testGetDenseElementsF16
209291
@run
210292
def testGetDenseElementsF16():

0 commit comments

Comments
 (0)