Skip to content

Commit 1824e45

Browse files
authored
[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (unrevert) (#115481)
This PR re-introduces the functionality of #113064, which was reverted in 0a68171 due to memory lifetime issues. Notice that I was not able to re-produce the ASan results myself, so I have not been able to verify that this PR really fixes the issue. --- Currently it is unsupported to: 1. Convert a MlirAttribute with type i1 to a numpy array 2. Convert a boolean numpy array to a MlirAttribute Currently the entire Python application violently crashes with a quite poor error message pybind/pybind11#3336 The complication handling these conversions, is that MlirAttribute represent booleans as a bit-packed i1 type, whereas numpy represents booleans as a byte array with 8 bit used per boolean. This PR proposes the following approach: 1. When converting a i1 typed MlirAttribute to a numpy array, we can not directly use the underlying raw data backing the MlirAttribute as a buffer to Python, as done for other types. Instead, a copy of the data is generated using numpy's unpackbits function, and the result is send back to Python. 2. When constructing a MlirAttribute from a numpy array, first the python data is read as a uint8_t to get it converted to the endianess used internally in mlir. Then the booleans are bitpacked using numpy's bitpack function, and the bitpacked array is saved as the MlirAttribute representation.
1 parent 804d3c4 commit 1824e45

File tree

2 files changed

+267
-95
lines changed

2 files changed

+267
-95
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 195 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "IRModule.h"
1414

1515
#include "PybindUtils.h"
16+
#include <pybind11/numpy.h>
1617

1718
#include "llvm/ADT/ScopeExit.h"
1819
#include "llvm/Support/raw_ostream.h"
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute
757758
throw py::error_already_set();
758759
}
759760
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
760-
SmallVector<int64_t> shape;
761-
if (explicitShape) {
762-
shape.append(explicitShape->begin(), explicitShape->end());
763-
} else {
764-
shape.append(view.shape, view.shape + view.ndim);
765-
}
766761

767-
MlirAttribute encodingAttr = mlirAttributeGetNull();
768762
MlirContext context = contextWrapper->get();
769-
770-
// Detect format codes that are suitable for bulk loading. This includes
771-
// all byte aligned integer and floating point types up to 8 bytes.
772-
// Notably, this excludes, bool (which needs to be bit-packed) and
773-
// other exotics which do not have a direct representation in the buffer
774-
// protocol (i.e. complex, etc).
775-
std::optional<MlirType> bulkLoadElementType;
776-
if (explicitType) {
777-
bulkLoadElementType = *explicitType;
778-
} else {
779-
std::string_view format(view.format);
780-
if (format == "f") {
781-
// f32
782-
assert(view.itemsize == 4 && "mismatched array itemsize");
783-
bulkLoadElementType = mlirF32TypeGet(context);
784-
} else if (format == "d") {
785-
// f64
786-
assert(view.itemsize == 8 && "mismatched array itemsize");
787-
bulkLoadElementType = mlirF64TypeGet(context);
788-
} else if (format == "e") {
789-
// f16
790-
assert(view.itemsize == 2 && "mismatched array itemsize");
791-
bulkLoadElementType = mlirF16TypeGet(context);
792-
} else if (isSignedIntegerFormat(format)) {
793-
if (view.itemsize == 4) {
794-
// i32
795-
bulkLoadElementType = signless
796-
? mlirIntegerTypeGet(context, 32)
797-
: mlirIntegerTypeSignedGet(context, 32);
798-
} else if (view.itemsize == 8) {
799-
// i64
800-
bulkLoadElementType = signless
801-
? mlirIntegerTypeGet(context, 64)
802-
: mlirIntegerTypeSignedGet(context, 64);
803-
} else if (view.itemsize == 1) {
804-
// i8
805-
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
806-
: mlirIntegerTypeSignedGet(context, 8);
807-
} else if (view.itemsize == 2) {
808-
// i16
809-
bulkLoadElementType = signless
810-
? mlirIntegerTypeGet(context, 16)
811-
: mlirIntegerTypeSignedGet(context, 16);
812-
}
813-
} else if (isUnsignedIntegerFormat(format)) {
814-
if (view.itemsize == 4) {
815-
// unsigned i32
816-
bulkLoadElementType = signless
817-
? mlirIntegerTypeGet(context, 32)
818-
: mlirIntegerTypeUnsignedGet(context, 32);
819-
} else if (view.itemsize == 8) {
820-
// unsigned i64
821-
bulkLoadElementType = signless
822-
? mlirIntegerTypeGet(context, 64)
823-
: mlirIntegerTypeUnsignedGet(context, 64);
824-
} else if (view.itemsize == 1) {
825-
// i8
826-
bulkLoadElementType = signless
827-
? mlirIntegerTypeGet(context, 8)
828-
: mlirIntegerTypeUnsignedGet(context, 8);
829-
} else if (view.itemsize == 2) {
830-
// i16
831-
bulkLoadElementType = signless
832-
? mlirIntegerTypeGet(context, 16)
833-
: mlirIntegerTypeUnsignedGet(context, 16);
834-
}
835-
}
836-
if (!bulkLoadElementType) {
837-
throw std::invalid_argument(
838-
std::string("unimplemented array format conversion from format: ") +
839-
std::string(format));
840-
}
841-
}
842-
843-
MlirType shapedType;
844-
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
845-
if (explicitShape) {
846-
throw std::invalid_argument("Shape can only be specified explicitly "
847-
"when the type is not a shaped type.");
848-
}
849-
shapedType = *bulkLoadElementType;
850-
} else {
851-
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
852-
*bulkLoadElementType, encodingAttr);
853-
}
854-
size_t rawBufferSize = view.len;
855-
MlirAttribute attr =
856-
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
763+
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
764+
explicitShape, context);
857765
if (mlirAttributeIsNull(attr)) {
858766
throw std::invalid_argument(
859767
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute
963871
// unsigned i16
964872
return bufferInfo<uint16_t>(shapedType);
965873
}
874+
} else if (mlirTypeIsAInteger(elementType) &&
875+
mlirIntegerTypeGetWidth(elementType) == 1) {
876+
// i1 / bool
877+
// We can not send the buffer directly back to Python, because the i1
878+
// values are bitpacked within MLIR. We call numpy's unpackbits function
879+
// to convert the bytes.
880+
return getBooleanBufferFromBitpackedAttribute();
966881
}
967882

968883
// TODO: Currently crashes the program.
@@ -1016,6 +931,191 @@ class PyDenseElementsAttribute
1016931
code == 'q';
1017932
}
1018933

934+
static MlirType
935+
getShapedType(std::optional<MlirType> bulkLoadElementType,
936+
std::optional<std::vector<int64_t>> explicitShape,
937+
Py_buffer &view) {
938+
SmallVector<int64_t> shape;
939+
if (explicitShape) {
940+
shape.append(explicitShape->begin(), explicitShape->end());
941+
} else {
942+
shape.append(view.shape, view.shape + view.ndim);
943+
}
944+
945+
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
946+
if (explicitShape) {
947+
throw std::invalid_argument("Shape can only be specified explicitly "
948+
"when the type is not a shaped type.");
949+
}
950+
return *bulkLoadElementType;
951+
} else {
952+
MlirAttribute encodingAttr = mlirAttributeGetNull();
953+
return mlirRankedTensorTypeGet(shape.size(), shape.data(),
954+
*bulkLoadElementType, encodingAttr);
955+
}
956+
}
957+
958+
static MlirAttribute getAttributeFromBuffer(
959+
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
960+
std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
961+
// Detect format codes that are suitable for bulk loading. This includes
962+
// all byte aligned integer and floating point types up to 8 bytes.
963+
// Notably, this excludes exotics types which do not have a direct
964+
// representation in the buffer protocol (i.e. complex, etc).
965+
std::optional<MlirType> bulkLoadElementType;
966+
if (explicitType) {
967+
bulkLoadElementType = *explicitType;
968+
} else {
969+
std::string_view format(view.format);
970+
if (format == "f") {
971+
// f32
972+
assert(view.itemsize == 4 && "mismatched array itemsize");
973+
bulkLoadElementType = mlirF32TypeGet(context);
974+
} else if (format == "d") {
975+
// f64
976+
assert(view.itemsize == 8 && "mismatched array itemsize");
977+
bulkLoadElementType = mlirF64TypeGet(context);
978+
} else if (format == "e") {
979+
// f16
980+
assert(view.itemsize == 2 && "mismatched array itemsize");
981+
bulkLoadElementType = mlirF16TypeGet(context);
982+
} else if (format == "?") {
983+
// i1
984+
// The i1 type needs to be bit-packed, so we will handle it seperately
985+
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
986+
context);
987+
} else if (isSignedIntegerFormat(format)) {
988+
if (view.itemsize == 4) {
989+
// i32
990+
bulkLoadElementType = signless
991+
? mlirIntegerTypeGet(context, 32)
992+
: mlirIntegerTypeSignedGet(context, 32);
993+
} else if (view.itemsize == 8) {
994+
// i64
995+
bulkLoadElementType = signless
996+
? mlirIntegerTypeGet(context, 64)
997+
: mlirIntegerTypeSignedGet(context, 64);
998+
} else if (view.itemsize == 1) {
999+
// i8
1000+
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1001+
: mlirIntegerTypeSignedGet(context, 8);
1002+
} else if (view.itemsize == 2) {
1003+
// i16
1004+
bulkLoadElementType = signless
1005+
? mlirIntegerTypeGet(context, 16)
1006+
: mlirIntegerTypeSignedGet(context, 16);
1007+
}
1008+
} else if (isUnsignedIntegerFormat(format)) {
1009+
if (view.itemsize == 4) {
1010+
// unsigned i32
1011+
bulkLoadElementType = signless
1012+
? mlirIntegerTypeGet(context, 32)
1013+
: mlirIntegerTypeUnsignedGet(context, 32);
1014+
} else if (view.itemsize == 8) {
1015+
// unsigned i64
1016+
bulkLoadElementType = signless
1017+
? mlirIntegerTypeGet(context, 64)
1018+
: mlirIntegerTypeUnsignedGet(context, 64);
1019+
} else if (view.itemsize == 1) {
1020+
// i8
1021+
bulkLoadElementType = signless
1022+
? mlirIntegerTypeGet(context, 8)
1023+
: mlirIntegerTypeUnsignedGet(context, 8);
1024+
} else if (view.itemsize == 2) {
1025+
// i16
1026+
bulkLoadElementType = signless
1027+
? mlirIntegerTypeGet(context, 16)
1028+
: mlirIntegerTypeUnsignedGet(context, 16);
1029+
}
1030+
}
1031+
if (!bulkLoadElementType) {
1032+
throw std::invalid_argument(
1033+
std::string("unimplemented array format conversion from format: ") +
1034+
std::string(format));
1035+
}
1036+
}
1037+
1038+
MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1039+
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1040+
}
1041+
1042+
// There is a complication for boolean numpy arrays, as numpy represents them
1043+
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
1044+
// per byte.
1045+
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1046+
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1047+
MlirContext &context) {
1048+
if (llvm::endianness::native != llvm::endianness::little) {
1049+
// Given we have no good way of testing the behavior on big-endian systems
1050+
// we will throw
1051+
throw py::type_error("Constructing a bit-packed MLIR attribute is "
1052+
"unsupported on big-endian systems");
1053+
}
1054+
1055+
py::array_t<uint8_t> unpackedArray(view.len,
1056+
static_cast<uint8_t *>(view.buf));
1057+
1058+
py::module numpy = py::module::import("numpy");
1059+
py::object packbitsFunc = numpy.attr("packbits");
1060+
py::object packedBooleans =
1061+
packbitsFunc(unpackedArray, "bitorder"_a = "little");
1062+
py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
1063+
1064+
MlirType bitpackedType =
1065+
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1066+
assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1067+
// Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1068+
// packedBooleans, hence the MlirAttribute will remain valid even when
1069+
// packedBooleans get reclaimed by the end of the function.
1070+
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1071+
pythonBuffer.ptr);
1072+
}
1073+
1074+
// This does the opposite transformation of
1075+
// `getBitpackedAttributeFromBooleanBuffer`
1076+
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
1077+
if (llvm::endianness::native != llvm::endianness::little) {
1078+
// Given we have no good way of testing the behavior on big-endian systems
1079+
// we will throw
1080+
throw py::type_error("Constructing a numpy array from a MLIR attribute "
1081+
"is unsupported on big-endian systems");
1082+
}
1083+
1084+
int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1085+
int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1086+
uint8_t *bitpackedData = static_cast<uint8_t *>(
1087+
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1088+
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
1089+
1090+
py::module numpy = py::module::import("numpy");
1091+
py::object unpackbitsFunc = numpy.attr("unpackbits");
1092+
py::object equalFunc = numpy.attr("equal");
1093+
py::object reshapeFunc = numpy.attr("reshape");
1094+
py::array unpackedBooleans =
1095+
unpackbitsFunc(packedArray, "bitorder"_a = "little");
1096+
1097+
// Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1098+
// We need to:
1099+
// 1. Slice away the padded bits
1100+
// 2. Make the boolean array have the correct shape
1101+
// 3. Convert the array to a boolean array
1102+
unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
1103+
unpackedBooleans = equalFunc(unpackedBooleans, 1);
1104+
1105+
std::vector<intptr_t> shape;
1106+
MlirType shapedType = mlirAttributeGetType(*this);
1107+
intptr_t rank = mlirShapedTypeGetRank(shapedType);
1108+
for (intptr_t i = 0; i < rank; ++i) {
1109+
shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1110+
}
1111+
unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1112+
1113+
// Make sure the returned py::buffer_view claims ownership of the data in
1114+
// `pythonBuffer` so it remains valid when Python reads it
1115+
py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
1116+
return pythonBuffer.request();
1117+
}
1118+
10191119
template <typename Type>
10201120
py::buffer_info bufferInfo(MlirType shapedType,
10211121
const char *explicitFormat = nullptr) {

mlir/test/python/ir/array_attributes.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,78 @@ def testGetDenseElementsF64():
326326
print(np.array(attr))
327327

328328

329+
### 1 bit/boolean integer arrays
330+
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
331+
@run
332+
def testGetDenseElementsI1Signless():
333+
with Context():
334+
array = np.array([True], dtype=np.bool_)
335+
attr = DenseElementsAttr.get(array)
336+
# CHECK: dense<true> : tensor<1xi1>
337+
print(attr)
338+
# CHECK{LITERAL}: [ True]
339+
print(np.array(attr))
340+
341+
array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
342+
attr = DenseElementsAttr.get(array)
343+
# CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
344+
print(attr)
345+
# CHECK{LITERAL}: [[ True False True]
346+
# CHECK{LITERAL}: [ True True False]]
347+
print(np.array(attr))
348+
349+
array = np.array(
350+
[[True, True, False, False], [True, False, True, False]], dtype=np.bool_
351+
)
352+
attr = DenseElementsAttr.get(array)
353+
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
354+
print(attr)
355+
# CHECK{LITERAL}: [[ True True False False]
356+
# CHECK{LITERAL}: [ True False True False]]
357+
print(np.array(attr))
358+
359+
array = np.array(
360+
[
361+
[True, True, False, False],
362+
[True, False, True, False],
363+
[False, False, False, False],
364+
[True, True, True, True],
365+
[True, False, False, True],
366+
],
367+
dtype=np.bool_,
368+
)
369+
attr = DenseElementsAttr.get(array)
370+
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
371+
print(attr)
372+
# CHECK{LITERAL}: [[ True True False False]
373+
# CHECK{LITERAL}: [ True False True False]
374+
# CHECK{LITERAL}: [False False False False]
375+
# CHECK{LITERAL}: [ True True True True]
376+
# CHECK{LITERAL}: [ True False False True]]
377+
print(np.array(attr))
378+
379+
array = np.array(
380+
[
381+
[True, True, False, False, True, True, False, False, False],
382+
[False, False, False, True, False, True, True, False, True],
383+
],
384+
dtype=np.bool_,
385+
)
386+
attr = DenseElementsAttr.get(array)
387+
# CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
388+
print(attr)
389+
# CHECK{LITERAL}: [[ True True False False True True False False False]
390+
# CHECK{LITERAL}: [False False False True False True True False True]]
391+
print(np.array(attr))
392+
393+
array = np.array([], dtype=np.bool_)
394+
attr = DenseElementsAttr.get(array)
395+
# CHECK: dense<> : tensor<0xi1>
396+
print(attr)
397+
# CHECK{LITERAL}: []
398+
print(np.array(attr))
399+
400+
329401
### 16 bit integer arrays
330402
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
331403
@run

0 commit comments

Comments
 (0)