@@ -72,6 +72,23 @@ or 255), then a splat will be created.
72
72
type or if the buffer does not meet expectations.
73
73
)" ;
74
74
75
+ static const char kDenseElementsAttrGetFromListDocstring [] =
76
+ R"( Gets a DenseElementsAttr from a Python list of attributes.
77
+
78
+ Args:
79
+ attrs: A list of attributes.
80
+ type: The desired shape and type of the resulting DenseElementsAttr.
81
+ If not provided, the element type is determined based on the types
82
+ of the attributes and the shape is `[len(attrs)]`.
83
+ context: Explicit context, if not from context manager.
84
+
85
+ Returns:
86
+ DenseElementsAttr on success.
87
+
88
+ Raises:
89
+ ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
90
+ )" ;
91
+
75
92
static const char kDenseResourceElementsAttrGetFromBufferDocstring [] =
76
93
R"( Gets a DenseResourceElementsAttr from a Python buffer or array.
77
94
@@ -647,6 +664,55 @@ class PyDenseElementsAttribute
647
664
static constexpr const char *pyClassName = " DenseElementsAttr" ;
648
665
using PyConcreteAttribute::PyConcreteAttribute;
649
666
667
+ static PyDenseElementsAttribute
668
+ getFromList (py::list attributes, std::optional<PyType> explicitType,
669
+ DefaultingPyMlirContext contextWrapper) {
670
+
671
+ if (py::len (attributes) == 0 ) {
672
+ throw py::value_error (" Attributes list must be non-empty" );
673
+ }
674
+
675
+ MlirType shapedType;
676
+ if (explicitType) {
677
+ if ((!mlirTypeIsAShaped (*explicitType) ||
678
+ !mlirShapedTypeHasStaticShape (*explicitType))) {
679
+ std::string message =
680
+ " Expected a static ShapedType for the shaped_type parameter: " ;
681
+ message.append (py::repr (py::cast (*explicitType)));
682
+ throw py::value_error (message);
683
+ }
684
+ shapedType = *explicitType;
685
+ } else {
686
+ SmallVector<int64_t > shape{static_cast <int64_t >(py::len (attributes))};
687
+ shapedType = mlirRankedTensorTypeGet (
688
+ shape.size (), shape.data (),
689
+ mlirAttributeGetType (pyTryCast<PyAttribute>(attributes[0 ])),
690
+ mlirAttributeGetNull ());
691
+ }
692
+
693
+ SmallVector<MlirAttribute> mlirAttributes;
694
+ mlirAttributes.reserve (py::len (attributes));
695
+ for (auto attribute : attributes) {
696
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
697
+ MlirType attrType = mlirAttributeGetType (mlirAttribute);
698
+ mlirAttributes.push_back (mlirAttribute);
699
+
700
+ if (!mlirTypeEqual (mlirShapedTypeGetElementType (shapedType), attrType)) {
701
+ std::string message = " All attributes must be of the same type and "
702
+ " match the type parameter: expected=" ;
703
+ message.append (py::repr (py::cast (shapedType)));
704
+ message.append (" , but got=" );
705
+ message.append (py::repr (py::cast (attrType)));
706
+ throw py::value_error (message);
707
+ }
708
+ }
709
+
710
+ MlirAttribute elements = mlirDenseElementsAttrGet (
711
+ shapedType, mlirAttributes.size (), mlirAttributes.data ());
712
+
713
+ return PyDenseElementsAttribute (contextWrapper->getRef (), elements);
714
+ }
715
+
650
716
static PyDenseElementsAttribute
651
717
getFromBuffer (py::buffer array, bool signless,
652
718
std::optional<PyType> explicitType,
@@ -883,6 +949,10 @@ class PyDenseElementsAttribute
883
949
py::arg (" type" ) = py::none (), py::arg (" shape" ) = py::none (),
884
950
py::arg (" context" ) = py::none (),
885
951
kDenseElementsAttrGetDocstring )
952
+ .def_static (" get" , PyDenseElementsAttribute::getFromList,
953
+ py::arg (" attrs" ), py::arg (" type" ) = py::none (),
954
+ py::arg (" context" ) = py::none (),
955
+ kDenseElementsAttrGetFromListDocstring )
886
956
.def_static (" get_splat" , PyDenseElementsAttribute::getSplat,
887
957
py::arg (" shaped_type" ), py::arg (" element_attr" ),
888
958
" Gets a DenseElementsAttr where all values are the same" )
@@ -954,8 +1024,8 @@ class PyDenseElementsAttribute
954
1024
}
955
1025
}; // namespace
956
1026
957
- // / Refinement of the PyDenseElementsAttribute for attributes containing integer
958
- // / (and boolean) values. Supports element access.
1027
+ // / Refinement of the PyDenseElementsAttribute for attributes containing
1028
+ // / integer (and boolean) values. Supports element access.
959
1029
class PyDenseIntElementsAttribute
960
1030
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
961
1031
PyDenseElementsAttribute> {
@@ -964,8 +1034,8 @@ class PyDenseIntElementsAttribute
964
1034
static constexpr const char *pyClassName = " DenseIntElementsAttr" ;
965
1035
using PyConcreteAttribute::PyConcreteAttribute;
966
1036
967
- // / Returns the element at the given linear position. Asserts if the index is
968
- // / out of range.
1037
+ // / Returns the element at the given linear position. Asserts if the index
1038
+ // / is out of range.
969
1039
py::int_ dunderGetItem (intptr_t pos) {
970
1040
if (pos < 0 || pos >= dunderLen ()) {
971
1041
throw py::index_error (" attempt to access out of bounds element" );
@@ -1267,7 +1337,8 @@ class PyStridedLayoutAttribute
1267
1337
return PyStridedLayoutAttribute (ctx->getRef (), attr);
1268
1338
},
1269
1339
py::arg (" rank" ), py::arg (" context" ) = py::none (),
1270
- " Gets a strided layout attribute with dynamic offset and strides of a "
1340
+ " Gets a strided layout attribute with dynamic offset and strides of "
1341
+ " a "
1271
1342
" given rank." );
1272
1343
c.def_property_readonly (
1273
1344
" offset" ,
0 commit comments