Skip to content

Commit f66cd9e

Browse files
[mlir] Add Python bindings for DenseResourceElementsAttr. (#66319)
Only construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents.
1 parent 1e40dfc commit f66cd9e

File tree

5 files changed

+252
-71
lines changed

5 files changed

+252
-71
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,23 @@ mlirDenseElementsAttrGetRawData(MlirAttribute attr);
558558
// Resource blob attributes.
559559
//===----------------------------------------------------------------------===//
560560

561+
MLIR_CAPI_EXPORTED bool
562+
mlirAttributeIsADenseResourceElements(MlirAttribute attr);
563+
564+
/// Unlike the typed accessors below, constructs the attribute with a raw
565+
/// data buffer and no type/alignment checking. Use a more strongly typed
566+
/// accessor if possible. If dataIsMutable is false, then an immutable
567+
/// AsmResourceBlob will be created and that passed data contents will be
568+
/// treated as const.
569+
/// If the deleter is non NULL, then it will be called when the data buffer
570+
/// can no longer be accessed (passing userData to it).
571+
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
572+
MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
573+
size_t dataAlignment, bool dataIsMutable,
574+
void (*deleter)(void *userData, const void *data, size_t size,
575+
size_t align),
576+
void *userData);
577+
561578
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
562579
MlirType shapedType, MlirStringRef name, intptr_t numElements,
563580
const int *elements);
@@ -600,13 +617,6 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
600617
intptr_t numElements,
601618
const double *elements);
602619

603-
/// Unlike the typed accessors above, constructs the attribute with a raw
604-
/// data buffer and no type/alignment checking. Use a more strongly typed
605-
/// accessor if possible.
606-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
607-
MlirType shapedType, MlirStringRef name, const void *data,
608-
size_t dataLength);
609-
610620
/// Returns the pos-th value (flat contiguous indexing) of a specific type
611621
/// contained by the given dense resource elements attribute.
612622
MLIR_CAPI_EXPORTED bool

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,32 @@ or 255), then a splat will be created.
7272
type or if the buffer does not meet expectations.
7373
)";
7474

75+
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
76+
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
77+
78+
This function does minimal validation or massaging of the data, and it is
79+
up to the caller to ensure that the buffer meets the characteristics
80+
implied by the shape.
81+
82+
The backing buffer and any user objects will be retained for the lifetime
83+
of the resource blob. This is typically bounded to the context but the
84+
resource can have a shorter lifespan depending on how it is used in
85+
subsequent processing.
86+
87+
Args:
88+
buffer: The array or buffer to convert.
89+
name: Name to provide to the resource (may be changed upon collision).
90+
type: The explicit ShapedType to construct the attribute with.
91+
context: Explicit context, if not from context manager.
92+
93+
Returns:
94+
DenseResourceElementsAttr on success.
95+
96+
Raises:
97+
ValueError: If the type of the buffer or array cannot be matched to an MLIR
98+
type or if the buffer does not meet expectations.
99+
)";
100+
75101
namespace {
76102

77103
static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -997,6 +1023,82 @@ class PyDenseIntElementsAttribute
9971023
}
9981024
};
9991025

1026+
class PyDenseResourceElementsAttribute
1027+
: public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1028+
public:
1029+
static constexpr IsAFunctionTy isaFunction =
1030+
mlirAttributeIsADenseResourceElements;
1031+
static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1032+
using PyConcreteAttribute::PyConcreteAttribute;
1033+
1034+
static PyDenseResourceElementsAttribute
1035+
getFromBuffer(py::buffer buffer, std::string name, PyType type,
1036+
std::optional<size_t> alignment, bool isMutable,
1037+
DefaultingPyMlirContext contextWrapper) {
1038+
if (!mlirTypeIsAShaped(type)) {
1039+
throw std::invalid_argument(
1040+
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
1041+
}
1042+
1043+
// Do not request any conversions as we must ensure to use caller
1044+
// managed memory.
1045+
int flags = PyBUF_STRIDES;
1046+
std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1047+
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1048+
throw py::error_already_set();
1049+
}
1050+
1051+
// This scope releaser will only release if we haven't yet transferred
1052+
// ownership.
1053+
auto freeBuffer = llvm::make_scope_exit([&]() {
1054+
if (view)
1055+
PyBuffer_Release(view.get());
1056+
});
1057+
1058+
if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1059+
throw std::invalid_argument("Contiguous buffer is required.");
1060+
}
1061+
1062+
// Infer alignment to be the stride of one element if not explicit.
1063+
size_t inferredAlignment;
1064+
if (alignment)
1065+
inferredAlignment = *alignment;
1066+
else
1067+
inferredAlignment = view->strides[view->ndim - 1];
1068+
1069+
// The userData is a Py_buffer* that the deleter owns.
1070+
auto deleter = [](void *userData, const void *data, size_t size,
1071+
size_t align) {
1072+
Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1073+
PyBuffer_Release(ownedView);
1074+
delete ownedView;
1075+
};
1076+
1077+
size_t rawBufferSize = view->len;
1078+
MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1079+
type, toMlirStringRef(name), view->buf, rawBufferSize,
1080+
inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1081+
if (mlirAttributeIsNull(attr)) {
1082+
throw std::invalid_argument(
1083+
"DenseResourceElementsAttr could not be constructed from the given "
1084+
"buffer. "
1085+
"This may mean that the Python buffer layout does not match that "
1086+
"MLIR expected layout and is a bug.");
1087+
}
1088+
view.release();
1089+
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1090+
}
1091+
1092+
static void bindDerived(ClassTy &c) {
1093+
c.def_static("get_from_buffer",
1094+
PyDenseResourceElementsAttribute::getFromBuffer,
1095+
py::arg("array"), py::arg("name"), py::arg("type"),
1096+
py::arg("alignment") = py::none(),
1097+
py::arg("is_mutable") = false, py::arg("context") = py::none(),
1098+
kDenseResourceElementsAttrGetFromBufferDocstring);
1099+
}
1100+
};
1101+
10001102
class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
10011103
public:
10021104
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
@@ -1273,6 +1375,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
12731375
PyGlobals::get().registerTypeCaster(
12741376
mlirDenseIntOrFPElementsAttrGetTypeID(),
12751377
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1378+
PyDenseResourceElementsAttribute::bind(m);
12761379

12771380
PyDictAttribute::bind(m);
12781381
PySymbolRefAttribute::bind(m);

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 65 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,30 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
770770
// Resource blob attributes.
771771
//===----------------------------------------------------------------------===//
772772

773+
bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
774+
return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
775+
}
776+
777+
MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
778+
MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
779+
size_t dataAlignment, bool dataIsMutable,
780+
void (*deleter)(void *userData, const void *data, size_t size,
781+
size_t align),
782+
void *userData) {
783+
AsmResourceBlob::DeleterFn cppDeleter = {};
784+
if (deleter) {
785+
cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
786+
deleter(userData, data, size, align);
787+
};
788+
}
789+
AsmResourceBlob blob(
790+
llvm::ArrayRef(static_cast<const char *>(data), dataLength),
791+
dataAlignment, std::move(cppDeleter), dataIsMutable);
792+
return wrap(
793+
DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
794+
unwrap(name), std::move(blob)));
795+
}
796+
773797
template <typename U, typename T>
774798
static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
775799
intptr_t numElements, const T *elements) {
@@ -778,139 +802,122 @@ static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
778802
llvm::ArrayRef(elements, numElements))));
779803
}
780804

781-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
805+
MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
782806
MlirType shapedType, MlirStringRef name, intptr_t numElements,
783807
const int *elements) {
784808
return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
785809
numElements, elements);
786810
}
787-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
811+
MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
788812
MlirType shapedType, MlirStringRef name, intptr_t numElements,
789813
const uint8_t *elements) {
790814
return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
791815
numElements, elements);
792816
}
793-
MLIR_CAPI_EXPORTED MlirAttribute
794-
mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType,
795-
MlirStringRef name,
796-
intptr_t numElements,
797-
const uint16_t *elements) {
817+
MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
818+
MlirType shapedType, MlirStringRef name, intptr_t numElements,
819+
const uint16_t *elements) {
798820
return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
799821
numElements, elements);
800822
}
801-
MLIR_CAPI_EXPORTED MlirAttribute
802-
mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType,
803-
MlirStringRef name,
804-
intptr_t numElements,
805-
const uint32_t *elements) {
823+
MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
824+
MlirType shapedType, MlirStringRef name, intptr_t numElements,
825+
const uint32_t *elements) {
806826
return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
807827
numElements, elements);
808828
}
809-
MLIR_CAPI_EXPORTED MlirAttribute
810-
mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType,
811-
MlirStringRef name,
812-
intptr_t numElements,
813-
const uint64_t *elements) {
829+
MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
830+
MlirType shapedType, MlirStringRef name, intptr_t numElements,
831+
const uint64_t *elements) {
814832
return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
815833
numElements, elements);
816834
}
817-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
835+
MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
818836
MlirType shapedType, MlirStringRef name, intptr_t numElements,
819837
const int8_t *elements) {
820838
return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
821839
numElements, elements);
822840
}
823-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
841+
MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
824842
MlirType shapedType, MlirStringRef name, intptr_t numElements,
825843
const int16_t *elements) {
826844
return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
827845
numElements, elements);
828846
}
829-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
847+
MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
830848
MlirType shapedType, MlirStringRef name, intptr_t numElements,
831849
const int32_t *elements) {
832850
return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
833851
numElements, elements);
834852
}
835-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
853+
MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
836854
MlirType shapedType, MlirStringRef name, intptr_t numElements,
837855
const int64_t *elements) {
838856
return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
839857
numElements, elements);
840858
}
841-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
859+
MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
842860
MlirType shapedType, MlirStringRef name, intptr_t numElements,
843861
const float *elements) {
844862
return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
845863
numElements, elements);
846864
}
847-
MLIR_CAPI_EXPORTED MlirAttribute
848-
mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
849-
MlirStringRef name,
850-
intptr_t numElements,
851-
const double *elements) {
865+
MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
866+
MlirType shapedType, MlirStringRef name, intptr_t numElements,
867+
const double *elements) {
852868
return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
853869
numElements, elements);
854870
}
855-
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
856-
MlirType shapedType, MlirStringRef name, const void *data,
857-
size_t dataLength) {
858-
return wrap(DenseResourceElementsAttr::get(
859-
llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
860-
UnmanagedAsmResourceBlob::allocateInferAlign(
861-
llvm::ArrayRef(static_cast<const char *>(data), dataLength))));
862-
}
863-
864871
template <typename U, typename T>
865872
static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
866873
return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
867874
}
868875

869-
MLIR_CAPI_EXPORTED bool
870-
mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
876+
bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
877+
intptr_t pos) {
871878
return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
872879
}
873-
MLIR_CAPI_EXPORTED uint8_t
874-
mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
880+
uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
881+
intptr_t pos) {
875882
return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
876883
}
877-
MLIR_CAPI_EXPORTED uint16_t
878-
mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
884+
uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
885+
intptr_t pos) {
879886
return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
880887
pos);
881888
}
882-
MLIR_CAPI_EXPORTED uint32_t
883-
mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
889+
uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
890+
intptr_t pos) {
884891
return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
885892
pos);
886893
}
887-
MLIR_CAPI_EXPORTED uint64_t
888-
mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
894+
uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
895+
intptr_t pos) {
889896
return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
890897
pos);
891898
}
892-
MLIR_CAPI_EXPORTED int8_t
893-
mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
899+
int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
900+
intptr_t pos) {
894901
return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
895902
}
896-
MLIR_CAPI_EXPORTED int16_t
897-
mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
903+
int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
904+
intptr_t pos) {
898905
return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
899906
}
900-
MLIR_CAPI_EXPORTED int32_t
901-
mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
907+
int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
908+
intptr_t pos) {
902909
return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
903910
}
904-
MLIR_CAPI_EXPORTED int64_t
905-
mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
911+
int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
912+
intptr_t pos) {
906913
return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
907914
}
908-
MLIR_CAPI_EXPORTED float
909-
mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
915+
float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
916+
intptr_t pos) {
910917
return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
911918
}
912-
MLIR_CAPI_EXPORTED double
913-
mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
919+
double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
920+
intptr_t pos) {
914921
return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
915922
}
916923

0 commit comments

Comments
 (0)