-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes #113064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
053d0b7
[Python] Attempt at getting boolean types working
kasper0406 d5da538
Refactorings
kasper0406 52b49ac
Cleanups
kasper0406 6d3204c
Fix style
kasper0406 f8a21fc
More styles
kasper0406 73df6fb
Use numpy to bitpack and unpack, to avoid additional fields
kasper0406 93156b1
Small refactoring
kasper0406 90868b8
Fix styles
kasper0406 d216d43
Minor rename
kasper0406 6543732
Code format
kasper0406 75c8264
Address comments
kasper0406 e5b10a3
Fix nits
kasper0406 b65d7d6
Fix comments
kasper0406 c9b2100
Throw an exception if used on big-endian machines
kasper0406 a1ae520
Fix C++ formatting
kasper0406 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
#include "IRModule.h" | ||
|
||
#include "PybindUtils.h" | ||
#include <pybind11/numpy.h> | ||
|
||
#include "llvm/ADT/ScopeExit.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute | |
throw py::error_already_set(); | ||
} | ||
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); | ||
SmallVector<int64_t> shape; | ||
if (explicitShape) { | ||
shape.append(explicitShape->begin(), explicitShape->end()); | ||
} else { | ||
shape.append(view.shape, view.shape + view.ndim); | ||
} | ||
|
||
MlirAttribute encodingAttr = mlirAttributeGetNull(); | ||
MlirContext context = contextWrapper->get(); | ||
|
||
// Detect format codes that are suitable for bulk loading. This includes | ||
// all byte aligned integer and floating point types up to 8 bytes. | ||
// Notably, this excludes, bool (which needs to be bit-packed) and | ||
// other exotics which do not have a direct representation in the buffer | ||
// protocol (i.e. complex, etc). | ||
std::optional<MlirType> bulkLoadElementType; | ||
if (explicitType) { | ||
bulkLoadElementType = *explicitType; | ||
} else { | ||
std::string_view format(view.format); | ||
if (format == "f") { | ||
// f32 | ||
assert(view.itemsize == 4 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF32TypeGet(context); | ||
} else if (format == "d") { | ||
// f64 | ||
assert(view.itemsize == 8 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF64TypeGet(context); | ||
} else if (format == "e") { | ||
// f16 | ||
assert(view.itemsize == 2 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF16TypeGet(context); | ||
} else if (isSignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeSignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeSignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeSignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeSignedGet(context, 16); | ||
} | ||
} else if (isUnsignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// unsigned i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeUnsignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// unsigned i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeUnsignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeUnsignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeUnsignedGet(context, 16); | ||
} | ||
} | ||
if (!bulkLoadElementType) { | ||
throw std::invalid_argument( | ||
std::string("unimplemented array format conversion from format: ") + | ||
std::string(format)); | ||
} | ||
} | ||
|
||
MlirType shapedType; | ||
if (mlirTypeIsAShaped(*bulkLoadElementType)) { | ||
if (explicitShape) { | ||
throw std::invalid_argument("Shape can only be specified explicitly " | ||
"when the type is not a shaped type."); | ||
} | ||
shapedType = *bulkLoadElementType; | ||
} else { | ||
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), | ||
*bulkLoadElementType, encodingAttr); | ||
} | ||
size_t rawBufferSize = view.len; | ||
MlirAttribute attr = | ||
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); | ||
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, | ||
explicitShape, context); | ||
if (mlirAttributeIsNull(attr)) { | ||
throw std::invalid_argument( | ||
"DenseElementsAttr could not be constructed from the given buffer. " | ||
|
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute | |
// unsigned i16 | ||
return bufferInfo<uint16_t>(shapedType); | ||
} | ||
} else if (mlirTypeIsAInteger(elementType) && | ||
mlirIntegerTypeGetWidth(elementType) == 1) { | ||
// i1 / bool | ||
// We can not send the buffer directly back to Python, because the i1 | ||
// values are bitpacked within MLIR. We call numpy's unpackbits function | ||
// to convert the bytes. | ||
return getBooleanBufferFromBitpackedAttribute(); | ||
} | ||
|
||
// TODO: Currently crashes the program. | ||
|
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute | |
code == 'q'; | ||
} | ||
|
||
// Resolves the shaped type used to build a dense elements attribute.
//
// If `bulkLoadElementType` is already a shaped type it is returned as-is, and
// passing `explicitShape` alongside it is rejected with
// std::invalid_argument. Otherwise a ranked tensor of that element type is
// created, taking its dimensions from `explicitShape` when provided and from
// the buffer `view` when not.
//
// `bulkLoadElementType` is dereferenced unconditionally, so callers must pass
// an engaged optional.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  const MlirType elementType = *bulkLoadElementType;

  // A shaped type already carries its own dimensions.
  if (mlirTypeIsAShaped(elementType)) {
    if (explicitShape.has_value()) {
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    }
    return elementType;
  }

  // Plain element type: wrap it in a ranked tensor with no encoding.
  SmallVector<int64_t> dims;
  if (explicitShape.has_value()) {
    dims.assign(explicitShape->begin(), explicitShape->end());
  } else {
    dims.assign(view.shape, view.shape + view.ndim);
  }
  return mlirRankedTensorTypeGet(dims.size(), dims.data(), elementType,
                                 mlirAttributeGetNull());
}
|
||
// Builds a dense elements attribute from the contents of a Python buffer.
//
// When `explicitType` is given it is used directly as the element (or shaped)
// type. Otherwise the element type is derived from the buffer's format code:
// "f"/"d"/"e" map to f32/f64/f16, "?" is routed to the bit-packing path, and
// signed/unsigned integer formats of 1, 2, 4 or 8 bytes map to the integer
// type of the matching width (signless when `signless` is set). Any other
// format raises std::invalid_argument.
//
// This covers the byte-aligned integer and floating point types up to
// 8 bytes, which are suitable for bulk loading. Exotic types with no direct
// representation in the buffer protocol (e.g. complex) are not handled.
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  std::optional<MlirType> elementType;

  if (explicitType) {
    elementType = *explicitType;
  } else {
    const std::string_view format(view.format);

    // i1: booleans must be bit-packed, so they are handled separately and do
    // not go through the raw bulk load below.
    if (format == "?") {
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);
    }

    // Picks the integer type whose width matches the buffer item size, or
    // leaves the result empty for unsupported sizes.
    auto integerTypeFor = [&](bool isSigned) -> std::optional<MlirType> {
      unsigned bitWidth = 0;
      switch (view.itemsize) {
      case 1:
        bitWidth = 8;
        break;
      case 2:
        bitWidth = 16;
        break;
      case 4:
        bitWidth = 32;
        break;
      case 8:
        bitWidth = 64;
        break;
      default:
        return std::nullopt;
      }
      if (signless)
        return mlirIntegerTypeGet(context, bitWidth);
      if (isSigned)
        return mlirIntegerTypeSignedGet(context, bitWidth);
      return mlirIntegerTypeUnsignedGet(context, bitWidth);
    };

    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      elementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      elementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      elementType = mlirF16TypeGet(context);
    } else if (isSignedIntegerFormat(format)) {
      elementType = integerTypeFor(/*isSigned=*/true);
    } else if (isUnsignedIntegerFormat(format)) {
      elementType = integerTypeFor(/*isSigned=*/false);
    }

    if (!elementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }

  const MlirType shapedType = getShapedType(elementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(shapedType, view.len, view.buf);
}
|
||
// There is a complication for boolean numpy arrays, as numpy represents them | ||
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans | ||
// per byte. | ||
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
    MlirContext &context) {
  // Converts a one-byte-per-boolean buffer into an i1 dense elements
  // attribute by packing eight booleans into each byte with numpy.packbits.
  // Raises py::type_error on big-endian hosts.

  // The behavior on big-endian systems cannot be tested, so refuse to run
  // there instead of risking a silently wrong bit order.
  if (llvm::endianness::native != llvm::endianness::little) {
    throw py::type_error("Constructing a bit-packed MLIR attribute is "
                         "unsupported on big-endian systems");
  }

  // Expose the raw buffer bytes to numpy as a flat uint8 array.
  py::array_t<uint8_t> byteArray(view.len, static_cast<uint8_t *>(view.buf));

  // pybind11's `_a` literal passes `bitorder` as a Python keyword argument.
  py::object packBits = py::module::import("numpy").attr("packbits");
  py::object packedArray = packBits(byteArray, "bitorder"_a = "little");
  py::buffer_info packedInfo = packedArray.cast<py::buffer>().request();

  // The shape still describes the unpacked booleans; only the storage is
  // packed.
  MlirType i1ShapedType =
      getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(i1ShapedType, packedInfo.size,
                                           packedInfo.ptr);
}
|
||
// This does the opposite transformation of | ||
// `getBitpackedAttributeFromBooleanBuffer` | ||
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  // Unpacks this attribute's bit-packed i1 data into a one-byte-per-boolean
  // buffer with the attribute's shape. Raises py::type_error on big-endian
  // hosts.
  if (llvm::endianness::native != llvm::endianness::little) {
    // Given we have no good way of testing the behavior on big-endian systems
    // we will throw
    throw py::type_error("Constructing a numpy array from a MLIR attribute "
                         "is unsupported on big-endian systems");
  }

  // NOTE(review): this assumes the raw data holds ceil(numBooleans / 8)
  // bytes; confirm that splat attributes are filtered out before this call.
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
  int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
  uint8_t *bitpackedData = static_cast<uint8_t *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

  py::module numpy = py::module::import("numpy");
  py::object unpackedBits =
      numpy.attr("unpackbits")(packedArray, "bitorder"_a = "little");

  // unpackbits yields eight values per input byte; drop the padding bits of
  // the last byte so the element count matches the attribute.
  py::object trimmedBits = unpackedBits[py::slice(0, numBooleans, 1)];

  // Give the array the attribute's shape and a boolean dtype so the exported
  // buffer describes itself (format "?") without a hand-built buffer_info.
  MlirType shapedType = mlirAttributeGetType(*this);
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  py::tuple dims(static_cast<size_t>(rank));
  for (intptr_t i = 0; i < rank; ++i) {
    dims[static_cast<size_t>(i)] =
        py::int_(mlirShapedTypeGetDimSize(shapedType, i));
  }
  py::object booleans =
      trimmedBits.attr("astype")(numpy.attr("bool_")).attr("reshape")(dims);

  // Keep the exported buffer read-only, as the other element types are.
  booleans.attr("setflags")("write"_a = false);

  // The data lives in a temporary numpy array, not in the attribute. Return
  // the buffer_info obtained from the array itself: it owns a buffer view on
  // the array, which keeps the storage alive for as long as the buffer_info
  // exists. Returning a bare pointer into the array would dangle once the
  // locals above are destroyed.
  return booleans.cast<py::buffer>().request();
}
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the data for the buffer_info. | ||
// Buffer is configured for read-only access below. | ||
// Buffer is configured for read-only access inside the `bufferInfo` call. | ||
Type *data = static_cast<Type *>( | ||
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); | ||
return bufferInfo<Type>(shapedType, data, explicitFormat); | ||
} | ||
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, Type *data, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the shape for the buffer_info. | ||
SmallVector<intptr_t, 4> shape; | ||
for (intptr_t i = 0; i < rank; ++i) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to move the if-statements to the `getAttributeFromBuffer` method, because the i1 case no longer follows the usual flow and instead calls `getBitpackedAttributeFromBooleanBuffer` to construct the MlirAttribute.