Skip to content

[mlir] Add Python bindings for DenseResourceElementsAttr. #66319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions mlir/include/mlir-c/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,23 @@ mlirDenseElementsAttrGetRawData(MlirAttribute attr);
// Resource blob attributes.
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool
mlirAttributeIsADenseResourceElements(MlirAttribute attr);

/// Unlike the typed accessors below, constructs the attribute with a raw
/// data buffer and no type/alignment checking. Use a more strongly typed
/// accessor if possible. If dataIsMutable is false, then an immutable
/// AsmResourceBlob will be created and that passed data contents will be
/// treated as const.
/// If the deleter is non NULL, then it will be called when the data buffer
/// can no longer be accessed (passing userData to it).
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
size_t dataAlignment, bool dataIsMutable,
void (*deleter)(void *userData, const void *data, size_t size,
size_t align),
void *userData);

MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int *elements);
Expand Down Expand Up @@ -600,13 +617,6 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
intptr_t numElements,
const double *elements);

/// Unlike the typed accessors above, constructs the attribute with a raw
/// data buffer and no type/alignment checking. Use a more strongly typed
/// accessor if possible.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, const void *data,
size_t dataLength);

/// Returns the pos-th value (flat contiguous indexing) of a specific type
/// contained by the given dense resource elements attribute.
MLIR_CAPI_EXPORTED bool
Expand Down
103 changes: 103 additions & 0 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,32 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";

static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
buffer: The array or buffer to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.

Returns:
DenseResourceElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";

namespace {

static MlirStringRef toMlirStringRef(const std::string &s) {
Expand Down Expand Up @@ -997,6 +1023,82 @@ class PyDenseIntElementsAttribute
}
};

class PyDenseResourceElementsAttribute
: public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsADenseResourceElements;
static constexpr const char *pyClassName = "DenseResourceElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;

static PyDenseResourceElementsAttribute
getFromBuffer(py::buffer buffer, std::string name, PyType type,
std::optional<size_t> alignment, bool isMutable,
DefaultingPyMlirContext contextWrapper) {
if (!mlirTypeIsAShaped(type)) {
throw std::invalid_argument(
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
}

// Do not request any conversions as we must ensure to use caller
// managed memory.
int flags = PyBUF_STRIDES;
std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
throw py::error_already_set();
}

// This scope releaser will only release if we haven't yet transferred
// ownership.
auto freeBuffer = llvm::make_scope_exit([&]() {
if (view)
PyBuffer_Release(view.get());
});

if (!PyBuffer_IsContiguous(view.get(), 'A')) {
throw std::invalid_argument("Contiguous buffer is required.");
}

// Infer alignment to be the stride of one element if not explicit.
size_t inferredAlignment;
if (alignment)
inferredAlignment = *alignment;
else
inferredAlignment = view->strides[view->ndim - 1];

// The userData is a Py_buffer* that the deleter owns.
auto deleter = [](void *userData, const void *data, size_t size,
size_t align) {
Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
PyBuffer_Release(ownedView);
delete ownedView;
};

size_t rawBufferSize = view->len;
MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
type, toMlirStringRef(name), view->buf, rawBufferSize,
inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseResourceElementsAttr could not be constructed from the given "
"buffer. "
"This may mean that the Python buffer layout does not match that "
"MLIR expected layout and is a bug.");
}
view.release();
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
}

static void bindDerived(ClassTy &c) {
c.def_static("get_from_buffer",
PyDenseResourceElementsAttribute::getFromBuffer,
py::arg("array"), py::arg("name"), py::arg("type"),
py::arg("alignment") = py::none(),
py::arg("is_mutable") = false, py::arg("context") = py::none(),
kDenseResourceElementsAttrGetFromBufferDocstring);
}
};

class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
Expand Down Expand Up @@ -1273,6 +1375,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyGlobals::get().registerTypeCaster(
mlirDenseIntOrFPElementsAttrGetTypeID(),
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
PyDenseResourceElementsAttribute::bind(m);

PyDictAttribute::bind(m);
PySymbolRefAttribute::bind(m);
Expand Down
123 changes: 65 additions & 58 deletions mlir/lib/CAPI/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,30 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
// Resource blob attributes.
//===----------------------------------------------------------------------===//

bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
}

MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
size_t dataAlignment, bool dataIsMutable,
void (*deleter)(void *userData, const void *data, size_t size,
size_t align),
void *userData) {
AsmResourceBlob::DeleterFn cppDeleter = {};
if (deleter) {
cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
deleter(userData, data, size, align);
};
}
AsmResourceBlob blob(
llvm::ArrayRef(static_cast<const char *>(data), dataLength),
dataAlignment, std::move(cppDeleter), dataIsMutable);
return wrap(
DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
unwrap(name), std::move(blob)));
}

template <typename U, typename T>
static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
intptr_t numElements, const T *elements) {
Expand All @@ -778,139 +802,122 @@ static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
llvm::ArrayRef(elements, numElements))));
}

MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int *elements) {
return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const uint8_t *elements) {
return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType,
MlirStringRef name,
intptr_t numElements,
const uint16_t *elements) {
MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const uint16_t *elements) {
return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType,
MlirStringRef name,
intptr_t numElements,
const uint32_t *elements) {
MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const uint32_t *elements) {
return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType,
MlirStringRef name,
intptr_t numElements,
const uint64_t *elements) {
MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const uint64_t *elements) {
return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int8_t *elements) {
return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int16_t *elements) {
return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int32_t *elements) {
return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const int64_t *elements) {
return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const float *elements) {
return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
MlirStringRef name,
intptr_t numElements,
const double *elements) {
MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, intptr_t numElements,
const double *elements) {
return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
numElements, elements);
}
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
MlirType shapedType, MlirStringRef name, const void *data,
size_t dataLength) {
return wrap(DenseResourceElementsAttr::get(
llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
UnmanagedAsmResourceBlob::allocateInferAlign(
llvm::ArrayRef(static_cast<const char *>(data), dataLength))));
}

template <typename U, typename T>
static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
}

MLIR_CAPI_EXPORTED bool
mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
}
MLIR_CAPI_EXPORTED uint8_t
mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
}
MLIR_CAPI_EXPORTED uint16_t
mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
pos);
}
MLIR_CAPI_EXPORTED uint32_t
mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
pos);
}
MLIR_CAPI_EXPORTED uint64_t
mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
pos);
}
MLIR_CAPI_EXPORTED int8_t
mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
}
MLIR_CAPI_EXPORTED int16_t
mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
}
MLIR_CAPI_EXPORTED int32_t
mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
}
MLIR_CAPI_EXPORTED int64_t
mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
}
MLIR_CAPI_EXPORTED float
mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
}
MLIR_CAPI_EXPORTED double
mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
intptr_t pos) {
return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
}

Expand Down
Loading