-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes #113064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
053d0b7
d5da538
52b49ac
6d3204c
f8a21fc
73df6fb
93156b1
90868b8
d216d43
6543732
75c8264
e5b10a3
b65d7d6
c9b2100
a1ae520
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
#include "IRModule.h" | ||
|
||
#include "PybindUtils.h" | ||
#include <pybind11/numpy.h> | ||
|
||
#include "llvm/ADT/ScopeExit.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute | |
throw py::error_already_set(); | ||
} | ||
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); | ||
SmallVector<int64_t> shape; | ||
if (explicitShape) { | ||
shape.append(explicitShape->begin(), explicitShape->end()); | ||
} else { | ||
shape.append(view.shape, view.shape + view.ndim); | ||
} | ||
|
||
MlirAttribute encodingAttr = mlirAttributeGetNull(); | ||
MlirContext context = contextWrapper->get(); | ||
|
||
// Detect format codes that are suitable for bulk loading. This includes | ||
// all byte aligned integer and floating point types up to 8 bytes. | ||
// Notably, this excludes, bool (which needs to be bit-packed) and | ||
// other exotics which do not have a direct representation in the buffer | ||
// protocol (i.e. complex, etc). | ||
std::optional<MlirType> bulkLoadElementType; | ||
if (explicitType) { | ||
bulkLoadElementType = *explicitType; | ||
} else { | ||
std::string_view format(view.format); | ||
if (format == "f") { | ||
// f32 | ||
assert(view.itemsize == 4 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF32TypeGet(context); | ||
} else if (format == "d") { | ||
// f64 | ||
assert(view.itemsize == 8 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF64TypeGet(context); | ||
} else if (format == "e") { | ||
// f16 | ||
assert(view.itemsize == 2 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF16TypeGet(context); | ||
} else if (isSignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeSignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeSignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeSignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeSignedGet(context, 16); | ||
} | ||
} else if (isUnsignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// unsigned i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeUnsignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// unsigned i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeUnsignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeUnsignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeUnsignedGet(context, 16); | ||
} | ||
} | ||
if (!bulkLoadElementType) { | ||
throw std::invalid_argument( | ||
std::string("unimplemented array format conversion from format: ") + | ||
std::string(format)); | ||
} | ||
} | ||
|
||
MlirType shapedType; | ||
if (mlirTypeIsAShaped(*bulkLoadElementType)) { | ||
if (explicitShape) { | ||
throw std::invalid_argument("Shape can only be specified explicitly " | ||
"when the type is not a shaped type."); | ||
} | ||
shapedType = *bulkLoadElementType; | ||
} else { | ||
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), | ||
*bulkLoadElementType, encodingAttr); | ||
} | ||
size_t rawBufferSize = view.len; | ||
MlirAttribute attr = | ||
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); | ||
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, | ||
explicitShape, context); | ||
if (mlirAttributeIsNull(attr)) { | ||
throw std::invalid_argument( | ||
"DenseElementsAttr could not be constructed from the given buffer. " | ||
|
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute | |
// unsigned i16 | ||
return bufferInfo<uint16_t>(shapedType); | ||
} | ||
} else if (mlirTypeIsAInteger(elementType) && | ||
mlirIntegerTypeGetWidth(elementType) == 1) { | ||
// i1 / bool | ||
// We can not send the buffer directly back to Python, because the i1 | ||
// values are bitpacked within MLIR. We call numpy's unpackbits function | ||
// to convert the bytes. | ||
return getBooleanBufferFromBitpackedAttribute(); | ||
} | ||
|
||
// TODO: Currently crashes the program. | ||
|
@@ -1016,14 +931,177 @@ class PyDenseElementsAttribute | |
code == 'q'; | ||
} | ||
|
||
// Resolves the shaped type a DenseElementsAttr will be built with.
//
// `bulkLoadElementType` must hold a value; it is dereferenced unconditionally.
// If it is already a shaped type it is returned as-is, and passing
// `explicitShape` alongside it raises std::invalid_argument. Otherwise a
// ranked tensor type (with a null encoding attribute) is created around the
// element type, using `explicitShape` when provided and the dimensions
// reported by the Python buffer `view` when not.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  const MlirType requestedType = *bulkLoadElementType;

  // An already-shaped type carries its own shape, so there is nothing to
  // derive; an additional explicit shape would be contradictory.
  if (mlirTypeIsAShaped(requestedType)) {
    if (explicitShape) {
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    }
    return requestedType;
  }

  // Element type only: pick the dimensions and wrap it in a ranked tensor.
  SmallVector<int64_t> dims;
  if (explicitShape)
    dims.append(explicitShape->begin(), explicitShape->end());
  else
    dims.append(view.shape, view.shape + view.ndim);

  return mlirRankedTensorTypeGet(dims.size(), dims.data(), requestedType,
                                 mlirAttributeGetNull());
}
|
||
// Builds a DenseElementsAttr from the contents of a Python buffer.
//
// The element type is `explicitType` when given. Otherwise it is derived from
// the buffer's format code and item size:
//   "f" / "d" / "e"          -> f32 / f64 / f16
//   "?"                      -> i1, delegated to
//                               getBitpackedAttributeFromBooleanBuffer
//                               (`signless` is not consulted on that path)
//   signed/unsigned integers -> 8/16/32/64-bit integer types, signless when
//                               `signless` is set, otherwise carrying the
//                               signedness of the format code
// Throws std::invalid_argument if no element type could be derived. The
// resulting attribute is returned unchecked; the call that produces it is
// mlirDenseElementsAttrRawBufferGet over `view.buf` / `view.len`.
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  // Detect format codes that are suitable for bulk loading. This includes
  // all byte aligned integer and floating point types up to 8 bytes.
  // Notably, this excludes exotic types which do not have a direct
  // representation in the buffer protocol (i.e. complex, etc).
  std::optional<MlirType> bulkLoadElementType;
  if (explicitType) {
    bulkLoadElementType = *explicitType;
  } else {
    std::string_view format(view.format);
    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      bulkLoadElementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      bulkLoadElementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      bulkLoadElementType = mlirF16TypeGet(context);
    } else if (format == "?") {
      // i1
      // The i1 type needs to be bit-packed, so it is handled separately.
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);
    } else {
      // Integer formats: the bit width follows from the item size, so one
      // code path covers i8/i16/i32/i64 for both signednesses.
      const bool isSigned = isSignedIntegerFormat(format);
      if (isSigned || isUnsignedIntegerFormat(format)) {
        switch (view.itemsize) {
        case 1:
        case 2:
        case 4:
        case 8: {
          const unsigned bitWidth = static_cast<unsigned>(view.itemsize) * 8;
          if (signless)
            bulkLoadElementType = mlirIntegerTypeGet(context, bitWidth);
          else if (isSigned)
            bulkLoadElementType = mlirIntegerTypeSignedGet(context, bitWidth);
          else
            bulkLoadElementType = mlirIntegerTypeUnsignedGet(context, bitWidth);
          break;
        }
        default:
          // Unsupported item size: leave the type unset so the error below
          // is raised.
          break;
        }
      }
    }
    if (!bulkLoadElementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }

  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}
|
||
// There is a complication for boolean numpy arrays, as numpy represents them | ||
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans | ||
// per byte. This function does the bit-packing respecting endianness. | |
static MlirAttribute getBitpackedAttributeFromBooleanBuffer( | ||
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, | ||
MlirContext &context) { | ||
// First read the content of the python buffer as u8's, to correct for | ||
// endianess | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry don't understand? what/why does this need to be "corrected"? I can see you're setting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From https://mlir.llvm.org/docs/BytecodeFormat/#fixed-width-integers my understanding is that MLIR attributes always have their data stored in a little-endian format. If that is not true for MlirAttributes in general, the code is wrong. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @stellaraccident this is the one i'm most uncertain about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's for bytecode rather than in memory. If this is converting the in memory format, then that section doesn't apply. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ack, one thing I don't understand though, is that if I follow the It is possible that it ends up calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know of a way to test this? From reading around other llvm docs, it seems like user-mode qemu is the common approach for testing big-endian systems. Do any of you have experience with that and the Python API? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
which does indeed hit which then goes to I think possibly you could just remove the explicit use of I mean really this is all academic because as I understand it, there are no extant big-endian systems we care about? Someone with more experience can comment on this hypothesis... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the pointers! Unfortunately, it will not work simply remove the "little" part from the numpy calls, as it defaults to big-endian and not the native byteorder. As you mention, I doubt any big-endian users will ever call this code anyways. What do you think about just throwing a "this is unsupported" exception (which is the current behaviour) if someone calls it on a big-endian system? I think that check can be made something like: if (llvm::endianness::native == llvm::endianness::big) {
throw py::type_error("Boolean types are unsupported on big-endian systems");
} I could also use the above check to switch between "little"/"big" endian in the numpy call, but I am considering if throwing is actually better, given we have no way to really test it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ya throwing here is acceptable and reasonable (if any IBM customers ever show up we can discuss it then 😅) |
||
MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8), | ||
explicitShape, view); | ||
MlirAttribute intermediateAttr = | ||
mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf); | ||
|
||
uint8_t *unpackedData = static_cast<uint8_t *>( | ||
const_cast<void *>(mlirDenseElementsAttrGetRawData(intermediateAttr))); | ||
py::array_t<uint8_t> unpackedArray(view.len, unpackedData); | ||
|
||
py::module numpy = py::module::import("numpy"); | ||
py::object packbits_func = numpy.attr("packbits"); | ||
py::object packed_booleans = | ||
packbits_func(unpackedArray, "bitorder"_a = "little"); | ||
Comment on lines
+1060
to
+1061
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ................ you can do this ..................... wow color me shocked I never noticed/knew you could do "kwargs" like this on cpp side. |
||
py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request(); | ||
|
||
MlirType bitpackedType = | ||
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); | ||
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, | ||
pythonBuffer.ptr); | ||
} | ||
|
||
// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`: the attribute's raw data is read
// as bit-packed bytes and expanded through numpy's `unpackbits` (little bit
// order) into one boolean per element.
//
// The returned buffer_info is obtained from the numpy result itself, so the
// buffer view it holds keeps that array alive. Handing out a bare pointer into
// the array instead would dangle as soon as the local Python objects in this
// function are released. Note that the buffer refers to a freshly unpacked
// copy, not to the attribute's own storage.
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
  int64_t numBitpackedBytes = (numBooleans + 7) / 8;

  uint8_t *bitpackedData = static_cast<uint8_t *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

  py::module numpy = py::module::import("numpy");
  py::object unpackedBooleans =
      numpy.attr("unpackbits")(packedArray, "bitorder"_a = "little");
  // `unpackbits` always yields a multiple of 8 values; drop the padding bits
  // of the last byte so the element count matches the attribute.
  unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
  // Turn the 0/1 uint8 values into a genuine boolean array instead of
  // reinterpreting uint8 storage as `bool`.
  unpackedBooleans = numpy.attr("equal")(unpackedBooleans, 1);

  // Restore the attribute's shape on the flat result.
  MlirType shapedType = mlirAttributeGetType(*this);
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  py::tuple dims(rank);
  for (intptr_t i = 0; i < rank; ++i)
    dims[static_cast<size_t>(i)] =
        py::int_(mlirShapedTypeGetDimSize(shapedType, i));
  unpackedBooleans = unpackedBooleans.attr("reshape")(dims);

  // The buffer_info owns its view of the array, keeping the data valid for
  // as long as the caller holds on to it.
  py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
  return pythonBuffer.request();
}
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the data for the buffer_info. | ||
// Buffer is configured for read-only access below. | ||
// Buffer is configured for read-only access inside the `bufferInfo` call. | ||
Type *data = static_cast<Type *>( | ||
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); | ||
return bufferInfo<Type>(shapedType, data, explicitFormat); | ||
} | ||
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, Type *data, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the shape for the buffer_info. | ||
SmallVector<intptr_t, 4> shape; | ||
for (intptr_t i = 0; i < rank; ++i) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to move the if-statements to the
getAttributeFromBuffer
method, as now the i1 case will not follow the usual flow, but instead calls getBitpackedAttributeFromBooleanBuffer
to construct the MlirAttribute.