@@ -72,6 +72,26 @@ or 255), then a splat will be created.
72
72
type or if the buffer does not meet expectations.
73
73
)" ;
74
74
75
+ static const char kDenseElementsAttrGetFromListDocstring [] =
76
+ R"( Gets a DenseElementsAttr from a Python list of attributes.
77
+
78
+ Note that it can be expensive to construct attributes individually.
79
+ For a large number of elements, consider using a Python buffer or array instead.
80
+
81
+ Args:
82
+ attrs: A list of attributes.
83
+ type: The desired shape and type of the resulting DenseElementsAttr.
84
+ If not provided, the element type is determined based on the type
85
+ of the 0th attribute and the shape is `[len(attrs)]`.
86
+ context: Explicit context, if not from context manager.
87
+
88
+ Returns:
89
+ DenseElementsAttr on success.
90
+
91
+ Raises:
92
+ ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
93
+ )" ;
94
+
75
95
static const char kDenseResourceElementsAttrGetFromBufferDocstring [] =
76
96
R"( Gets a DenseResourceElementsAttr from a Python buffer or array.
77
97
@@ -647,6 +667,55 @@ class PyDenseElementsAttribute
647
667
static constexpr const char *pyClassName = " DenseElementsAttr" ;
648
668
using PyConcreteAttribute::PyConcreteAttribute;
649
669
670
+ static PyDenseElementsAttribute
671
+ getFromList (py::list attributes, std::optional<PyType> explicitType,
672
+ DefaultingPyMlirContext contextWrapper) {
673
+
674
+ const size_t numAttributes = py::len (attributes);
675
+ if (numAttributes == 0 )
676
+ throw py::value_error (" Attributes list must be non-empty" );
677
+
678
+ MlirType shapedType;
679
+ if (explicitType) {
680
+ if ((!mlirTypeIsAShaped (*explicitType) ||
681
+ !mlirShapedTypeHasStaticShape (*explicitType))) {
682
+ std::string message =
683
+ " Expected a static ShapedType for the shaped_type parameter: " ;
684
+ message.append (py::repr (py::cast (*explicitType)));
685
+ throw py::value_error (message);
686
+ }
687
+ shapedType = *explicitType;
688
+ } else {
689
+ SmallVector<int64_t > shape{static_cast <int64_t >(numAttributes)};
690
+ shapedType = mlirRankedTensorTypeGet (
691
+ shape.size (), shape.data (),
692
+ mlirAttributeGetType (pyTryCast<PyAttribute>(attributes[0 ])),
693
+ mlirAttributeGetNull ());
694
+ }
695
+
696
+ SmallVector<MlirAttribute> mlirAttributes;
697
+ mlirAttributes.reserve (numAttributes);
698
+ for (auto attribute : attributes) {
699
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
700
+ MlirType attrType = mlirAttributeGetType (mlirAttribute);
701
+ mlirAttributes.push_back (mlirAttribute);
702
+
703
+ if (!mlirTypeEqual (mlirShapedTypeGetElementType (shapedType), attrType)) {
704
+ std::string message = " All attributes must be of the same type and "
705
+ " match the type parameter: expected=" ;
706
+ message.append (py::repr (py::cast (shapedType)));
707
+ message.append (" , but got=" );
708
+ message.append (py::repr (py::cast (attrType)));
709
+ throw py::value_error (message);
710
+ }
711
+ }
712
+
713
+ MlirAttribute elements = mlirDenseElementsAttrGet (
714
+ shapedType, mlirAttributes.size (), mlirAttributes.data ());
715
+
716
+ return PyDenseElementsAttribute (contextWrapper->getRef (), elements);
717
+ }
718
+
650
719
static PyDenseElementsAttribute
651
720
getFromBuffer (py::buffer array, bool signless,
652
721
std::optional<PyType> explicitType,
@@ -883,6 +952,10 @@ class PyDenseElementsAttribute
883
952
py::arg (" type" ) = py::none (), py::arg (" shape" ) = py::none (),
884
953
py::arg (" context" ) = py::none (),
885
954
kDenseElementsAttrGetDocstring )
955
+ .def_static (" get" , PyDenseElementsAttribute::getFromList,
956
+ py::arg (" attrs" ), py::arg (" type" ) = py::none (),
957
+ py::arg (" context" ) = py::none (),
958
+ kDenseElementsAttrGetFromListDocstring )
886
959
.def_static (" get_splat" , PyDenseElementsAttribute::getSplat,
887
960
py::arg (" shaped_type" ), py::arg (" element_attr" ),
888
961
" Gets a DenseElementsAttr where all values are the same" )
@@ -954,8 +1027,8 @@ class PyDenseElementsAttribute
954
1027
}
955
1028
}; // namespace
956
1029
957
- // / Refinement of the PyDenseElementsAttribute for attributes containing integer
958
- // / (and boolean) values. Supports element access.
1030
+ // / Refinement of the PyDenseElementsAttribute for attributes containing
1031
+ // / integer (and boolean) values. Supports element access.
959
1032
class PyDenseIntElementsAttribute
960
1033
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
961
1034
PyDenseElementsAttribute> {
@@ -964,8 +1037,8 @@ class PyDenseIntElementsAttribute
964
1037
static constexpr const char *pyClassName = " DenseIntElementsAttr" ;
965
1038
using PyConcreteAttribute::PyConcreteAttribute;
966
1039
967
- // / Returns the element at the given linear position. Asserts if the index is
968
- // / out of range.
1040
+ // / Returns the element at the given linear position. Asserts if the index
1041
+ // / is out of range.
969
1042
py::int_ dunderGetItem (intptr_t pos) {
970
1043
if (pos < 0 || pos >= dunderLen ()) {
971
1044
throw py::index_error (" attempt to access out of bounds element" );
@@ -1267,7 +1340,8 @@ class PyStridedLayoutAttribute
1267
1340
return PyStridedLayoutAttribute (ctx->getRef (), attr);
1268
1341
},
1269
1342
py::arg (" rank" ), py::arg (" context" ) = py::none (),
1270
- " Gets a strided layout attribute with dynamic offset and strides of a "
1343
+ " Gets a strided layout attribute with dynamic offset and strides of "
1344
+ " a "
1271
1345
" given rank." );
1272
1346
c.def_property_readonly (
1273
1347
" offset" ,
0 commit comments