Skip to content

[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes #113064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 2, 2024
Merged
272 changes: 175 additions & 97 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "IRModule.h"

#include "PybindUtils.h"
#include <pybind11/numpy.h>

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -757,103 +758,10 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
shape.append(view.shape, view.shape + view.ndim);
}

MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();

// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes, bool (which needs to be bit-packed) and
// other exotics which do not have a direct representation in the buffer
// protocol (i.e. complex, etc).
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
} else {
std::string_view format(view.format);
if (format == "f") {
// f32
assert(view.itemsize == 4 && "mismatched array itemsize");
bulkLoadElementType = mlirF32TypeGet(context);
} else if (format == "d") {
// f64
assert(view.itemsize == 8 && "mismatched array itemsize");
bulkLoadElementType = mlirF64TypeGet(context);
} else if (format == "e") {
// f16
assert(view.itemsize == 2 && "mismatched array itemsize");
bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
} else if (view.itemsize == 8) {
// i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeSignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// unsigned i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
} else if (view.itemsize == 8) {
// unsigned i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeUnsignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeUnsignedGet(context, 16);
}
}
if (!bulkLoadElementType) {
throw std::invalid_argument(
std::string("unimplemented array format conversion from format: ") +
std::string(format));
}
}

MlirType shapedType;
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
throw std::invalid_argument("Shape can only be specified explicitly "
"when the type is not a shaped type.");
}
shapedType = *bulkLoadElementType;
} else {
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
*bulkLoadElementType, encodingAttr);
}
size_t rawBufferSize = view.len;
MlirAttribute attr =
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move the if-statements to the getAttributeFromBuffer method, as now the i1 case will not follow the usual flow, but instead call getBitpackedAttributeFromBooleanBuffer to construct the MlirAttribute.

explicitShape, context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
Expand Down Expand Up @@ -963,6 +871,13 @@ class PyDenseElementsAttribute
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 1) {
// i1 / bool
// We can not send the buffer directly back to Python, because the i1
// values are bitpacked within MLIR. We call numpy's unpackbits function
// to convert the bytes.
return getBooleanBufferFromBitpackedAttribute();
}

// TODO: Currently crashes the program.
Expand Down Expand Up @@ -1016,14 +931,177 @@ class PyDenseElementsAttribute
code == 'q';
}

// Resolves the shaped type that a buffer-backed attribute should carry.
//
// If `bulkLoadElementType` is itself a shaped type it is returned unchanged;
// combining it with an `explicitShape` is rejected with
// std::invalid_argument. Otherwise a ranked tensor type over that element
// type is built, taking its dimensions from `explicitShape` when provided and
// from the shape of `view` otherwise.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  const MlirType requestedType = *bulkLoadElementType;

  // A shaped type already carries its own dimensions.
  if (mlirTypeIsAShaped(requestedType)) {
    if (explicitShape) {
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    }
    return requestedType;
  }

  // Element type only: wrap it in a ranked tensor without an encoding.
  SmallVector<int64_t> dims;
  if (explicitShape)
    dims.append(explicitShape->begin(), explicitShape->end());
  else
    dims.append(view.shape, view.shape + view.ndim);

  return mlirRankedTensorTypeGet(dims.size(), dims.data(), requestedType,
                                 mlirAttributeGetNull());
}

// Creates a dense elements attribute from the contents of a Python buffer.
//
// The element type is `explicitType` when one is given; in that case the
// buffer's format code is not inspected. Otherwise the type is derived from
// the format code and item size of `view`:
//   "f" / "d" / "e"          -> f32 / f64 / f16
//   "?"                      -> i1, delegated to
//                               getBitpackedAttributeFromBooleanBuffer
//   signed/unsigned integer  -> 8, 16, 32 or 64 bit integer; signless when
//                               `signless` is set, otherwise carrying the
//                               signedness of the format code
// Formats (or integer item sizes) without a mapping raise
// std::invalid_argument. Exotic types which do not have a direct
// representation in the buffer protocol (i.e. complex, etc) are therefore
// rejected.
//
// The shaped type is resolved through getShapedType, and the result of
// mlirDenseElementsAttrRawBufferGet is returned as is.
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  std::optional<MlirType> bulkLoadElementType;
  if (explicitType) {
    bulkLoadElementType = *explicitType;
  } else {
    std::string_view format(view.format);

    // i1: the values need to be bit-packed, so they are handled separately
    // and never reach the bulk-load path below.
    if (format == "?")
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);

    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      bulkLoadElementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      bulkLoadElementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      bulkLoadElementType = mlirF16TypeGet(context);
    } else {
      const bool isSigned = isSignedIntegerFormat(format);
      if (isSigned || isUnsignedIntegerFormat(format)) {
        switch (view.itemsize) {
        case 1: // i8
        case 2: // i16
        case 4: // i32
        case 8: // i64
        {
          const unsigned bitWidth = static_cast<unsigned>(view.itemsize) * 8;
          if (signless)
            bulkLoadElementType = mlirIntegerTypeGet(context, bitWidth);
          else if (isSigned)
            bulkLoadElementType = mlirIntegerTypeSignedGet(context, bitWidth);
          else
            bulkLoadElementType =
                mlirIntegerTypeUnsignedGet(context, bitWidth);
          break;
        }
        default:
          // Any other item size has no mapping; reported below.
          break;
        }
      }
    }

    if (!bulkLoadElementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }

  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}

// There is a complication for boolean numpy arrays, as numpy represents them
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
// per byte. This function does the bit-packing respecting endianness.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
MlirContext &context) {
// First read the content of the python buffer as u8's, to correct for
// endianess
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry don't understand? what/why does this need to be "corrected"? I can see you're setting little below but shouldn't you be preserving/using the endianness for the host arch (wherever the bindings are running...)? probably I don't understand something here...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From https://mlir.llvm.org/docs/BytecodeFormat/#fixed-width-integers my understanding is that MLIR attributes always have their data stored in a little-endian format. If that is not true for MlirAttributes in general, the code is wrong.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stellaraccident this is the one i'm most uncertain about

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's for bytecode rather than in memory. If this is converting the in memory format, then that section doesn't apply.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, one thing I don't understand though, is that if I follow the mlirDenseElementsAttrRawBufferGet function, it goes to DenseElementsAttr::getFromRawBuffer, and I believe it eventually ends up in BuiltinAttributes.cpp where it uses the writeBits and readBits functions respectively. Both of these have special handling for big endian, where it seems to convert the format.

It is possible that it ends up calling Base::get instead, which may be why the above is not true. My C++ navigation skills are currently failing me when it comes to following which code this Base::get call will execute.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know of a way to test this? From reading around other llvm docs, it seems like user-mode qemu is the common approach for testing big-endian systems. Do any of you have experience with that and the Python API?

Copy link
Contributor

@makslevental makslevental Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it goes to DenseElementsAttr::getFromRawBuffer

https://github.com/makslevental/llvm-project/blob/84aa02d3fa1f1f614c4f3c144ec118b2f05ae6b0/mlir/lib/IR/BuiltinAttributes.cpp#L1064

which does indeed hit Base::get

https://github.com/makslevental/llvm-project/blob/84aa02d3fa1f1f614c4f3c144ec118b2f05ae6b0/mlir/lib/IR/BuiltinAttributes.cpp#L1354

which then goes to static ConcreteT StorageUserBase::get

https://github.com/makslevental/llvm-project/blob/a9636b7f60f283926c66e96c036f5b5d9e57c026/mlir/include/mlir/IR/StorageUniquerSupport.h#L177

I think possibly you could just remove the explicit use of little in the various places and defer to host endianness?

I mean really this is all academic because as I understand it, there are no extant big-endian systems we care about? Someone with more experience can comment on this hypothesis...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointers!

Unfortunately, it will not work to simply remove the "little" part from the numpy calls, as it defaults to big-endian and not the native byteorder.

As you mention, I doubt any big-endian users will ever call this code anyways. What do you think about just throwing a "this is unsupported" exception (which is the current behaviour) if someone calls it on a big-endian system?

I think that check can be made something like:

if (llvm::endianness::native == llvm::endianness::big) {
    throw py::type_error("Boolean types are unsupported on big-endian systems");
}

I could also use the above check to switch between "little"/"big" endian in the numpy call, but I am considering if throwing is actually better, given we have no way to really test it?

Copy link
Contributor

@makslevental makslevental Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya throwing here is acceptable and reasonable (if any IBM customers ever show up we can discuss it then 😅)

MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8),
explicitShape, view);
MlirAttribute intermediateAttr =
mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);

uint8_t *unpackedData = static_cast<uint8_t *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(intermediateAttr)));
py::array_t<uint8_t> unpackedArray(view.len, unpackedData);

py::module numpy = py::module::import("numpy");
py::object packbits_func = numpy.attr("packbits");
py::object packed_booleans =
packbits_func(unpackedArray, "bitorder"_a = "little");
Comment on lines +1060 to +1061
Copy link
Contributor

@makslevental makslevental Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

................ you can do this ..................... wow color me shocked I never noticed/knew you could do "kwargs" like this on cpp side.

py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();

MlirType bitpackedType =
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
pythonBuffer.ptr);
}

// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`
// Builds a byte-per-element boolean buffer from this attribute's bit-packed
// i1 storage. This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`.
//
// The attribute's raw data is wrapped as a uint8 array of
// ceil(numElements / 8) bytes and expanded with numpy's `unpackbits` using
// little bit order. The expanded data is then exposed as a "?" buffer shaped
// like the attribute's type.
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
  // Round up so that a trailing, partially filled byte is included.
  int64_t numBitpackedBytes = (numBooleans + 7) / 8;
  uint8_t *bitpackedData = static_cast<uint8_t *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

  // NOTE(review): the bit order is hard-coded to "little"; behaviour on
  // big-endian hosts is unverified - confirm or reject such hosts explicitly.
  py::module numpy = py::module::import("numpy");
  py::object unpackbitsFunc = numpy.attr("unpackbits");
  py::object unpackedBooleans =
      unpackbitsFunc(packedArray, "bitorder"_a = "little");
  py::buffer_info pythonBuffer = unpackedBooleans.cast<py::buffer>().request();

  // NOTE(review): `pythonBuffer.ptr` refers to storage owned by the local
  // `unpackedBooleans`; confirm the pointer-taking `bufferInfo` overload does
  // not hand out this pointer beyond the lifetime of that array.
  MlirType shapedType = mlirAttributeGetType(*this);
  return bufferInfo<bool>(shapedType, static_cast<bool *>(pythonBuffer.ptr),
                          "?");
}

// Exposes this attribute's raw element storage as a Python buffer with
// elements of type `Type`.
//
// The raw data pointer of the attribute is reinterpreted as `Type *` and
// forwarded, together with `shapedType` and `explicitFormat`, to the
// pointer-taking `bufferInfo` overload, which builds the actual buffer_info.
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
                           const char *explicitFormat = nullptr) {
  // Prepare the data for the buffer_info.
  // Buffer is configured for read-only access inside the `bufferInfo` call.
  Type *data = static_cast<Type *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  return bufferInfo<Type>(shapedType, data, explicitFormat);
}

template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
Expand Down
72 changes: 72 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,78 @@ def testGetDenseElementsF64():
print(np.array(attr))


### 1 bit/boolean integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
@run
def testGetDenseElementsI1Signless():
    # Round-trips boolean numpy arrays through DenseElementsAttr: each case
    # prints the attribute and then the numpy array rebuilt from it.
    with Context():
        # A single boolean (1 bit).
        bools = np.array([True], dtype=np.bool_)
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<true> : tensor<1xi1>
        print(dense)
        # CHECK: {{\[}} True]
        print(np.array(dense))

        # 2x3 = 6 booleans.
        bools = np.array(
            [
                [True, False, True],
                [True, True, False],
            ],
            dtype=np.bool_,
        )
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<{{\[}}[true, false, true], [true, true, false]]> : tensor<2x3xi1>
        print(dense)
        # CHECK: {{\[}}[ True False True]
        # CHECK: {{\[}} True True False]]
        print(np.array(dense))

        # 2x4 = 8 booleans.
        bools = np.array(
            [
                [True, True, False, False],
                [True, False, True, False],
            ],
            dtype=np.bool_,
        )
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
        print(dense)
        # CHECK: {{\[}}[ True True False False]
        # CHECK: {{\[}} True False True False]]
        print(np.array(dense))

        # 5x4 = 20 booleans.
        bools = np.array(
            [
                [True, True, False, False],
                [True, False, True, False],
                [False, False, False, False],
                [True, True, True, True],
                [True, False, False, True],
            ],
            dtype=np.bool_,
        )
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
        print(dense)
        # CHECK: {{\[}}[ True True False False]
        # CHECK: {{\[}} True False True False]
        # CHECK: {{\[}}False False False False]
        # CHECK: {{\[}} True True True True]
        # CHECK: {{\[}} True False False True]]
        print(np.array(dense))

        # 2x9 = 18 booleans, 9 per row.
        bools = np.array(
            [
                [True, True, False, False, True, True, False, False, False],
                [False, False, False, True, False, True, True, False, True],
            ],
            dtype=np.bool_,
        )
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<{{\[}}[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
        print(dense)
        # CHECK: {{\[}}[ True True False False True True False False False]
        # CHECK: {{\[}}False False False True False True True False True]]
        print(np.array(dense))

        # Empty array.
        bools = np.array([], dtype=np.bool_)
        dense = DenseElementsAttr.get(bools)
        # CHECK: dense<> : tensor<0xi1>
        print(dense)
        # CHECK: {{\[}}]
        print(np.array(dense))


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
Expand Down