|
15 | 15 | #include "PybindUtils.h"
|
16 | 16 |
|
17 | 17 | #include "llvm/ADT/ScopeExit.h"
|
| 18 | +#include "llvm/Support/raw_ostream.h" |
18 | 19 |
|
19 | 20 | #include "mlir-c/BuiltinAttributes.h"
|
20 | 21 | #include "mlir-c/BuiltinTypes.h"
|
@@ -72,6 +73,27 @@ or 255), then a splat will be created.
|
72 | 73 | type or if the buffer does not meet expectations.
|
73 | 74 | )";
|
74 | 75 |
|
| 76 | +static const char kDenseElementsAttrGetFromListDocstring[] = |
| 77 | + R"(Gets a DenseElementsAttr from a Python list of attributes. |
| 78 | +
|
| 79 | +Note that it can be expensive to construct attributes individually. |
| 80 | +For a large number of elements, consider using a Python buffer or array instead. |
| 81 | +
|
| 82 | +Args: |
| 83 | + attrs: A list of attributes. |
| 84 | + type: The desired shape and type of the resulting DenseElementsAttr. |
| 85 | + If not provided, the element type is determined based on the type |
| 86 | + of the 0th attribute and the shape is `[len(attrs)]`. |
| 87 | + context: Explicit context, if not from context manager. |
| 88 | +
|
| 89 | +Returns: |
| 90 | + DenseElementsAttr on success. |
| 91 | +
|
| 92 | +Raises: |
| 93 | + ValueError: If the type of the attributes does not match the type |
| 94 | + specified by `shaped_type`. |
| 95 | +)"; |
| 96 | + |
75 | 97 | static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
|
76 | 98 | R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
|
77 | 99 |
|
@@ -647,6 +669,57 @@ class PyDenseElementsAttribute
|
647 | 669 | static constexpr const char *pyClassName = "DenseElementsAttr";
|
648 | 670 | using PyConcreteAttribute::PyConcreteAttribute;
|
649 | 671 |
|
| 672 | + static PyDenseElementsAttribute |
| 673 | + getFromList(py::list attributes, std::optional<PyType> explicitType, |
| 674 | + DefaultingPyMlirContext contextWrapper) { |
| 675 | + |
| 676 | + const size_t numAttributes = py::len(attributes); |
| 677 | + if (numAttributes == 0) |
| 678 | + throw py::value_error("Attributes list must be non-empty."); |
| 679 | + |
| 680 | + MlirType shapedType; |
| 681 | + if (explicitType) { |
| 682 | + if ((!mlirTypeIsAShaped(*explicitType) || |
| 683 | + !mlirShapedTypeHasStaticShape(*explicitType))) { |
| 684 | + |
| 685 | + std::string message; |
| 686 | + llvm::raw_string_ostream os(message); |
| 687 | + os << "Expected a static ShapedType for the shaped_type parameter: " |
| 688 | + << py::repr(py::cast(*explicitType)); |
| 689 | + throw py::value_error(os.str()); |
| 690 | + } |
| 691 | + shapedType = *explicitType; |
| 692 | + } else { |
| 693 | + SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)}; |
| 694 | + shapedType = mlirRankedTensorTypeGet( |
| 695 | + shape.size(), shape.data(), |
| 696 | + mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), |
| 697 | + mlirAttributeGetNull()); |
| 698 | + } |
| 699 | + |
| 700 | + SmallVector<MlirAttribute> mlirAttributes; |
| 701 | + mlirAttributes.reserve(numAttributes); |
| 702 | + for (const py::handle &attribute : attributes) { |
| 703 | + MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); |
| 704 | + MlirType attrType = mlirAttributeGetType(mlirAttribute); |
| 705 | + mlirAttributes.push_back(mlirAttribute); |
| 706 | + |
| 707 | + if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { |
| 708 | + std::string message; |
| 709 | + llvm::raw_string_ostream os(message); |
| 710 | + os << "All attributes must be of the same type and match " |
| 711 | + << "the type parameter: expected=" << py::repr(py::cast(shapedType)) |
| 712 | + << ", but got=" << py::repr(py::cast(attrType)); |
| 713 | + throw py::value_error(os.str()); |
| 714 | + } |
| 715 | + } |
| 716 | + |
| 717 | + MlirAttribute elements = mlirDenseElementsAttrGet( |
| 718 | + shapedType, mlirAttributes.size(), mlirAttributes.data()); |
| 719 | + |
| 720 | + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
| 721 | + } |
| 722 | + |
650 | 723 | static PyDenseElementsAttribute
|
651 | 724 | getFromBuffer(py::buffer array, bool signless,
|
652 | 725 | std::optional<PyType> explicitType,
|
@@ -883,6 +956,10 @@ class PyDenseElementsAttribute
|
883 | 956 | py::arg("type") = py::none(), py::arg("shape") = py::none(),
|
884 | 957 | py::arg("context") = py::none(),
|
885 | 958 | kDenseElementsAttrGetDocstring)
|
| 959 | + .def_static("get", PyDenseElementsAttribute::getFromList, |
| 960 | + py::arg("attrs"), py::arg("type") = py::none(), |
| 961 | + py::arg("context") = py::none(), |
| 962 | + kDenseElementsAttrGetFromListDocstring) |
886 | 963 | .def_static("get_splat", PyDenseElementsAttribute::getSplat,
|
887 | 964 | py::arg("shaped_type"), py::arg("element_attr"),
|
888 | 965 | "Gets a DenseElementsAttr where all values are the same")
|
|
0 commit comments