Skip to content

[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes #113064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 2, 2024
Merged
278 changes: 181 additions & 97 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "IRModule.h"

#include "PybindUtils.h"
#include <pybind11/numpy.h>

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -757,103 +758,10 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
shape.append(view.shape, view.shape + view.ndim);
}

MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();

// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes, bool (which needs to be bit-packed) and
// other exotics which do not have a direct representation in the buffer
// protocol (i.e. complex, etc).
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
} else {
std::string_view format(view.format);
if (format == "f") {
// f32
assert(view.itemsize == 4 && "mismatched array itemsize");
bulkLoadElementType = mlirF32TypeGet(context);
} else if (format == "d") {
// f64
assert(view.itemsize == 8 && "mismatched array itemsize");
bulkLoadElementType = mlirF64TypeGet(context);
} else if (format == "e") {
// f16
assert(view.itemsize == 2 && "mismatched array itemsize");
bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
} else if (view.itemsize == 8) {
// i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeSignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// unsigned i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
} else if (view.itemsize == 8) {
// unsigned i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeUnsignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeUnsignedGet(context, 16);
}
}
if (!bulkLoadElementType) {
throw std::invalid_argument(
std::string("unimplemented array format conversion from format: ") +
std::string(format));
}
}

MlirType shapedType;
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
throw std::invalid_argument("Shape can only be specified explicitly "
"when the type is not a shaped type.");
}
shapedType = *bulkLoadElementType;
} else {
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
*bulkLoadElementType, encodingAttr);
}
size_t rawBufferSize = view.len;
MlirAttribute attr =
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move the if-statements to the getAttributeFromBuffer method, as now the i1 case will not follow the usual flow, but instead call getBitpackedAttributeFromBooleanBuffer to construct the MlirAttribute.

explicitShape, context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
Expand Down Expand Up @@ -963,6 +871,13 @@ class PyDenseElementsAttribute
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 1) {
// i1 / bool
// We can not send the buffer directly back to Python, because the i1
// values are bitpacked within MLIR. We call numpy's unpackbits function
// to convert the bytes.
return getBooleanBufferFromBitpackedAttribute();
}

// TODO: Currently crashes the program.
Expand Down Expand Up @@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
code == 'q';
}

// Resolves the shaped type used to build a dense elements attribute.
//
// `bulkLoadElementType` must hold a value. When it is already a shaped type
// it is returned unchanged, and supplying `explicitShape` as well is rejected
// with std::invalid_argument. Otherwise a ranked tensor of that element type
// is created, taking its dimensions from `explicitShape` when given and from
// the dimensions of `view` when not.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  SmallVector<int64_t> dims;
  if (explicitShape)
    dims.append(explicitShape->begin(), explicitShape->end());
  else
    dims.append(view.shape, view.shape + view.ndim);

  const MlirType requestedType = *bulkLoadElementType;
  if (!mlirTypeIsAShaped(requestedType)) {
    // Plain element type: wrap it in a ranked tensor without an encoding.
    return mlirRankedTensorTypeGet(dims.size(), dims.data(), requestedType,
                                   mlirAttributeGetNull());
  }

  // The type already carries a shape, so a second one would be ambiguous.
  if (explicitShape) {
    throw std::invalid_argument("Shape can only be specified explicitly "
                                "when the type is not a shaped type.");
  }
  return requestedType;
}

// Builds a dense elements attribute from a Python buffer.
//
// When `explicitType` is given it is used as is. Otherwise the element type
// is derived from the buffer's format code: "f", "d" and "e" map to f32, f64
// and f16, and integer formats map to 8/16/32/64 bit integers according to
// the item size. Integer types are signless when `signless` is set and
// signed/unsigned otherwise. Boolean buffers ("?") cannot be loaded in bulk
// because they have to be bit-packed, so they are delegated to
// `getBitpackedAttributeFromBooleanBuffer`.
//
// Throws std::invalid_argument for formats that have no mapping (for example
// complex numbers).
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  std::optional<MlirType> bulkLoadElementType;
  if (explicitType) {
    bulkLoadElementType = *explicitType;
  } else {
    std::string_view format(view.format);

    // i1: needs bit-packing, so it takes its own path.
    if (format == "?") {
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);
    }

    // Picks the integer type whose width matches the buffer's item size.
    // Yields no value for item sizes other than 1, 2, 4 or 8 bytes.
    auto integerTypeForItemSize =
        [&](bool isUnsigned) -> std::optional<MlirType> {
      unsigned bitWidth = 0;
      switch (view.itemsize) {
      case 1:
        bitWidth = 8;
        break;
      case 2:
        bitWidth = 16;
        break;
      case 4:
        bitWidth = 32;
        break;
      case 8:
        bitWidth = 64;
        break;
      default:
        return std::nullopt;
      }
      if (signless)
        return mlirIntegerTypeGet(context, bitWidth);
      if (isUnsigned)
        return mlirIntegerTypeUnsignedGet(context, bitWidth);
      return mlirIntegerTypeSignedGet(context, bitWidth);
    };

    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      bulkLoadElementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      bulkLoadElementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      bulkLoadElementType = mlirF16TypeGet(context);
    } else if (isSignedIntegerFormat(format)) {
      bulkLoadElementType = integerTypeForItemSize(/*isUnsigned=*/false);
    } else if (isUnsignedIntegerFormat(format)) {
      bulkLoadElementType = integerTypeForItemSize(/*isUnsigned=*/true);
    }

    if (!bulkLoadElementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }

  // Hand the raw bytes to MLIR together with the resolved shaped type.
  MlirType shapedType =
      getShapedType(bulkLoadElementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(shapedType, view.len, view.buf);
}

// Builds an i1 dense elements attribute from a boolean Python buffer.
//
// There is a complication for boolean numpy arrays, as numpy represents them
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
// per byte.
//
// The bytes of `view` are packed with numpy's `packbits` (little bit order)
// and the packed bytes are handed to MLIR. The attribute's shape comes from
// `explicitShape` when given and from the dimensions of `view` otherwise.
//
// Throws py::type_error on big-endian hosts.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
MlirContext &context) {
if (llvm::endianness::native != llvm::endianness::little) {
// Given we have no good way of testing the behavior on big-endian systems
// we will throw
throw py::type_error("Constructing a bit-packed MLIR attribute is "
"unsupported on big-endian systems");
}

// View the buffer as a flat array of `view.len` bytes, one per boolean.
py::array_t<uint8_t> unpackedArray(view.len,
static_cast<uint8_t *>(view.buf));

// Let numpy do the packing: 8 booleans per output byte.
py::module numpy = py::module::import("numpy");
py::object packbits_func = numpy.attr("packbits");
py::object packed_booleans =
packbits_func(unpackedArray, "bitorder"_a = "little");
Comment on lines +1060 to +1061
Copy link
Contributor

@makslevental makslevental Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

................ you can do this ..................... wow color me shocked I never noticed/knew you could do "kwargs" like this on cpp side.

py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();

// The element type is i1; the shape still describes the unpacked booleans.
MlirType bitpackedType =
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
pythonBuffer.ptr);
}

// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`: it expands the bit-packed i1
// storage of this attribute into a numpy boolean array (one byte per element)
// with the attribute's shape, and returns that array's buffer description.
//
// The returned py::buffer_info is obtained from the numpy array itself, so it
// keeps the unpacked data alive. Returning a description built from a raw
// pointer into a function-local array would leave the caller with a pointer
// to memory that is released when this function returns.
//
// Throws py::type_error on big-endian hosts.
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  if (llvm::endianness::native != llvm::endianness::little) {
    // Given we have no good way of testing the behavior on big-endian systems
    // we will throw
    throw py::type_error("Constructing a numpy array from a MLIR attribute "
                         "is unsupported on big-endian systems");
  }

  // NOTE(review): assumes the raw data holds ceil(numBooleans / 8) bytes;
  // confirm this also holds for splat attributes.
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
  int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
  uint8_t *bitpackedData = static_cast<uint8_t *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

  py::module numpy = py::module::import("numpy");
  py::object unpackedBooleans =
      numpy.attr("unpackbits")(packedArray, "bitorder"_a = "little");

  // unpackbits works on whole bytes and yields a flat array of 0 / 1 values,
  // so we still have to:
  //   1. drop the padding bits of the last byte,
  //   2. turn the 0 / 1 values into booleans,
  //   3. give the array the shape of the attribute.
  unpackedBooleans = unpackedBooleans[py::slice(
      py::int_(0), py::int_(numBooleans), py::int_(1))];
  unpackedBooleans = numpy.attr("equal")(unpackedBooleans, 1);

  MlirType shapedType = mlirAttributeGetType(*this);
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  py::tuple dims(rank);
  for (intptr_t i = 0; i < rank; ++i)
    dims[static_cast<size_t>(i)] = mlirShapedTypeGetDimSize(shapedType, i);
  unpackedBooleans = numpy.attr("reshape")(unpackedBooleans, dims);

  // Request the buffer from the array so that the returned description owns
  // a view of it and the data stays valid while Python reads it.
  py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
  return pythonBuffer.request();
}

// Describes this attribute's raw element storage as a Python buffer of
// `Type` elements. Fetches the raw data pointer and forwards to the
// `bufferInfo` overload that takes an explicit data pointer.
//
// `explicitFormat` is passed through unchanged; it defaults to nullptr.
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
                           const char *explicitFormat = nullptr) {
  // Prepare the data for the buffer_info.
  // Buffer is configured for read-only access inside the `bufferInfo` call.
  Type *data = static_cast<Type *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  return bufferInfo<Type>(shapedType, data, explicitFormat);
}

template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
Expand Down
72 changes: 72 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,78 @@ def testGetDenseElementsF64():
print(np.array(attr))


### 1 bit/boolean integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
@run
def testGetDenseElementsI1Signless():
# Round-trips boolean numpy arrays of several sizes through
# DenseElementsAttr, printing the attribute and the array recovered from it.
with Context():
# Single element.
array = np.array([True], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK: dense<true> : tensor<1xi1>
print(attr)
# CHECK{LITERAL}: [ True]
print(np.array(attr))

# 2x3 = 6 elements.
array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
print(attr)
# CHECK{LITERAL}: [[ True False True]
# CHECK{LITERAL}: [ True True False]]
print(np.array(attr))

# 2x4 = 8 elements.
array = np.array(
[[True, True, False, False], [True, False, True, False]], dtype=np.bool_
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False]
# CHECK{LITERAL}: [ True False True False]]
print(np.array(attr))

# 5x4 = 20 elements, which is not a multiple of 8.
array = np.array(
[
[True, True, False, False],
[True, False, True, False],
[False, False, False, False],
[True, True, True, True],
[True, False, False, True],
],
dtype=np.bool_,
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False]
# CHECK{LITERAL}: [ True False True False]
# CHECK{LITERAL}: [False False False False]
# CHECK{LITERAL}: [ True True True True]
# CHECK{LITERAL}: [ True False False True]]
print(np.array(attr))

# 2x9 = 18 elements, with rows of 9 booleans each.
array = np.array(
[
[True, True, False, False, True, True, False, False, False],
[False, False, False, True, False, True, True, False, True],
],
dtype=np.bool_,
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False True True False False False]
# CHECK{LITERAL}: [False False False True False True True False True]]
print(np.array(attr))

# Empty array.
array = np.array([], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK: dense<> : tensor<0xi1>
print(attr)
# CHECK{LITERAL}: []
print(np.array(attr))


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
Expand Down