Skip to content

Commit 0a68171

Browse files
committed
Revert "[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (#113064)"
This reverts commit fb7bf7a. There is an ASan issue here, see the discussion on #113064.
1 parent 0c18def commit 0a68171

File tree

2 files changed

+97
-253
lines changed

2 files changed

+97
-253
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 97 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "IRModule.h"
1414

1515
#include "PybindUtils.h"
16-
#include <pybind11/numpy.h>
1716

1817
#include "llvm/ADT/ScopeExit.h"
1918
#include "llvm/Support/raw_ostream.h"
@@ -758,10 +757,103 @@ class PyDenseElementsAttribute
758757
throw py::error_already_set();
759758
}
760759
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
760+
SmallVector<int64_t> shape;
761+
if (explicitShape) {
762+
shape.append(explicitShape->begin(), explicitShape->end());
763+
} else {
764+
shape.append(view.shape, view.shape + view.ndim);
765+
}
761766

767+
MlirAttribute encodingAttr = mlirAttributeGetNull();
762768
MlirContext context = contextWrapper->get();
763-
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
764-
explicitShape, context);
769+
770+
// Detect format codes that are suitable for bulk loading. This includes
771+
// all byte aligned integer and floating point types up to 8 bytes.
772+
// Notably, this excludes, bool (which needs to be bit-packed) and
773+
// other exotics which do not have a direct representation in the buffer
774+
// protocol (i.e. complex, etc).
775+
std::optional<MlirType> bulkLoadElementType;
776+
if (explicitType) {
777+
bulkLoadElementType = *explicitType;
778+
} else {
779+
std::string_view format(view.format);
780+
if (format == "f") {
781+
// f32
782+
assert(view.itemsize == 4 && "mismatched array itemsize");
783+
bulkLoadElementType = mlirF32TypeGet(context);
784+
} else if (format == "d") {
785+
// f64
786+
assert(view.itemsize == 8 && "mismatched array itemsize");
787+
bulkLoadElementType = mlirF64TypeGet(context);
788+
} else if (format == "e") {
789+
// f16
790+
assert(view.itemsize == 2 && "mismatched array itemsize");
791+
bulkLoadElementType = mlirF16TypeGet(context);
792+
} else if (isSignedIntegerFormat(format)) {
793+
if (view.itemsize == 4) {
794+
// i32
795+
bulkLoadElementType = signless
796+
? mlirIntegerTypeGet(context, 32)
797+
: mlirIntegerTypeSignedGet(context, 32);
798+
} else if (view.itemsize == 8) {
799+
// i64
800+
bulkLoadElementType = signless
801+
? mlirIntegerTypeGet(context, 64)
802+
: mlirIntegerTypeSignedGet(context, 64);
803+
} else if (view.itemsize == 1) {
804+
// i8
805+
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
806+
: mlirIntegerTypeSignedGet(context, 8);
807+
} else if (view.itemsize == 2) {
808+
// i16
809+
bulkLoadElementType = signless
810+
? mlirIntegerTypeGet(context, 16)
811+
: mlirIntegerTypeSignedGet(context, 16);
812+
}
813+
} else if (isUnsignedIntegerFormat(format)) {
814+
if (view.itemsize == 4) {
815+
// unsigned i32
816+
bulkLoadElementType = signless
817+
? mlirIntegerTypeGet(context, 32)
818+
: mlirIntegerTypeUnsignedGet(context, 32);
819+
} else if (view.itemsize == 8) {
820+
// unsigned i64
821+
bulkLoadElementType = signless
822+
? mlirIntegerTypeGet(context, 64)
823+
: mlirIntegerTypeUnsignedGet(context, 64);
824+
} else if (view.itemsize == 1) {
825+
// i8
826+
bulkLoadElementType = signless
827+
? mlirIntegerTypeGet(context, 8)
828+
: mlirIntegerTypeUnsignedGet(context, 8);
829+
} else if (view.itemsize == 2) {
830+
// i16
831+
bulkLoadElementType = signless
832+
? mlirIntegerTypeGet(context, 16)
833+
: mlirIntegerTypeUnsignedGet(context, 16);
834+
}
835+
}
836+
if (!bulkLoadElementType) {
837+
throw std::invalid_argument(
838+
std::string("unimplemented array format conversion from format: ") +
839+
std::string(format));
840+
}
841+
}
842+
843+
MlirType shapedType;
844+
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
845+
if (explicitShape) {
846+
throw std::invalid_argument("Shape can only be specified explicitly "
847+
"when the type is not a shaped type.");
848+
}
849+
shapedType = *bulkLoadElementType;
850+
} else {
851+
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
852+
*bulkLoadElementType, encodingAttr);
853+
}
854+
size_t rawBufferSize = view.len;
855+
MlirAttribute attr =
856+
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
765857
if (mlirAttributeIsNull(attr)) {
766858
throw std::invalid_argument(
767859
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -871,13 +963,6 @@ class PyDenseElementsAttribute
871963
// unsigned i16
872964
return bufferInfo<uint16_t>(shapedType);
873965
}
874-
} else if (mlirTypeIsAInteger(elementType) &&
875-
mlirIntegerTypeGetWidth(elementType) == 1) {
876-
// i1 / bool
877-
// We can not send the buffer directly back to Python, because the i1
878-
// values are bitpacked within MLIR. We call numpy's unpackbits function
879-
// to convert the bytes.
880-
return getBooleanBufferFromBitpackedAttribute();
881966
}
882967

883968
// TODO: Currently crashes the program.
@@ -931,183 +1016,14 @@ class PyDenseElementsAttribute
9311016
code == 'q';
9321017
}
9331018

934-
static MlirType
935-
getShapedType(std::optional<MlirType> bulkLoadElementType,
936-
std::optional<std::vector<int64_t>> explicitShape,
937-
Py_buffer &view) {
938-
SmallVector<int64_t> shape;
939-
if (explicitShape) {
940-
shape.append(explicitShape->begin(), explicitShape->end());
941-
} else {
942-
shape.append(view.shape, view.shape + view.ndim);
943-
}
944-
945-
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
946-
if (explicitShape) {
947-
throw std::invalid_argument("Shape can only be specified explicitly "
948-
"when the type is not a shaped type.");
949-
}
950-
return *bulkLoadElementType;
951-
} else {
952-
MlirAttribute encodingAttr = mlirAttributeGetNull();
953-
return mlirRankedTensorTypeGet(shape.size(), shape.data(),
954-
*bulkLoadElementType, encodingAttr);
955-
}
956-
}
957-
958-
static MlirAttribute getAttributeFromBuffer(
959-
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
960-
std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
961-
// Detect format codes that are suitable for bulk loading. This includes
962-
// all byte aligned integer and floating point types up to 8 bytes.
963-
// Notably, this excludes exotics types which do not have a direct
964-
// representation in the buffer protocol (i.e. complex, etc).
965-
std::optional<MlirType> bulkLoadElementType;
966-
if (explicitType) {
967-
bulkLoadElementType = *explicitType;
968-
} else {
969-
std::string_view format(view.format);
970-
if (format == "f") {
971-
// f32
972-
assert(view.itemsize == 4 && "mismatched array itemsize");
973-
bulkLoadElementType = mlirF32TypeGet(context);
974-
} else if (format == "d") {
975-
// f64
976-
assert(view.itemsize == 8 && "mismatched array itemsize");
977-
bulkLoadElementType = mlirF64TypeGet(context);
978-
} else if (format == "e") {
979-
// f16
980-
assert(view.itemsize == 2 && "mismatched array itemsize");
981-
bulkLoadElementType = mlirF16TypeGet(context);
982-
} else if (format == "?") {
983-
// i1
984-
// The i1 type needs to be bit-packed, so we will handle it seperately
985-
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
986-
context);
987-
} else if (isSignedIntegerFormat(format)) {
988-
if (view.itemsize == 4) {
989-
// i32
990-
bulkLoadElementType = signless
991-
? mlirIntegerTypeGet(context, 32)
992-
: mlirIntegerTypeSignedGet(context, 32);
993-
} else if (view.itemsize == 8) {
994-
// i64
995-
bulkLoadElementType = signless
996-
? mlirIntegerTypeGet(context, 64)
997-
: mlirIntegerTypeSignedGet(context, 64);
998-
} else if (view.itemsize == 1) {
999-
// i8
1000-
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1001-
: mlirIntegerTypeSignedGet(context, 8);
1002-
} else if (view.itemsize == 2) {
1003-
// i16
1004-
bulkLoadElementType = signless
1005-
? mlirIntegerTypeGet(context, 16)
1006-
: mlirIntegerTypeSignedGet(context, 16);
1007-
}
1008-
} else if (isUnsignedIntegerFormat(format)) {
1009-
if (view.itemsize == 4) {
1010-
// unsigned i32
1011-
bulkLoadElementType = signless
1012-
? mlirIntegerTypeGet(context, 32)
1013-
: mlirIntegerTypeUnsignedGet(context, 32);
1014-
} else if (view.itemsize == 8) {
1015-
// unsigned i64
1016-
bulkLoadElementType = signless
1017-
? mlirIntegerTypeGet(context, 64)
1018-
: mlirIntegerTypeUnsignedGet(context, 64);
1019-
} else if (view.itemsize == 1) {
1020-
// i8
1021-
bulkLoadElementType = signless
1022-
? mlirIntegerTypeGet(context, 8)
1023-
: mlirIntegerTypeUnsignedGet(context, 8);
1024-
} else if (view.itemsize == 2) {
1025-
// i16
1026-
bulkLoadElementType = signless
1027-
? mlirIntegerTypeGet(context, 16)
1028-
: mlirIntegerTypeUnsignedGet(context, 16);
1029-
}
1030-
}
1031-
if (!bulkLoadElementType) {
1032-
throw std::invalid_argument(
1033-
std::string("unimplemented array format conversion from format: ") +
1034-
std::string(format));
1035-
}
1036-
}
1037-
1038-
MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1039-
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1040-
}
1041-
1042-
// There is a complication for boolean numpy arrays, as numpy represents them
1043-
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
1044-
// per byte.
1045-
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1046-
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1047-
MlirContext &context) {
1048-
if (llvm::endianness::native != llvm::endianness::little) {
1049-
// Given we have no good way of testing the behavior on big-endian systems
1050-
// we will throw
1051-
throw py::type_error("Constructing a bit-packed MLIR attribute is "
1052-
"unsupported on big-endian systems");
1053-
}
1054-
1055-
py::array_t<uint8_t> unpackedArray(view.len,
1056-
static_cast<uint8_t *>(view.buf));
1057-
1058-
py::module numpy = py::module::import("numpy");
1059-
py::object packbits_func = numpy.attr("packbits");
1060-
py::object packed_booleans =
1061-
packbits_func(unpackedArray, "bitorder"_a = "little");
1062-
py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
1063-
1064-
MlirType bitpackedType =
1065-
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1066-
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1067-
pythonBuffer.ptr);
1068-
}
1069-
1070-
// This does the opposite transformation of
1071-
// `getBitpackedAttributeFromBooleanBuffer`
1072-
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
1073-
if (llvm::endianness::native != llvm::endianness::little) {
1074-
// Given we have no good way of testing the behavior on big-endian systems
1075-
// we will throw
1076-
throw py::type_error("Constructing a numpy array from a MLIR attribute "
1077-
"is unsupported on big-endian systems");
1078-
}
1079-
1080-
int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1081-
int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1082-
uint8_t *bitpackedData = static_cast<uint8_t *>(
1083-
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1084-
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
1085-
1086-
py::module numpy = py::module::import("numpy");
1087-
py::object unpackbits_func = numpy.attr("unpackbits");
1088-
py::object unpacked_booleans =
1089-
unpackbits_func(packedArray, "bitorder"_a = "little");
1090-
py::buffer_info pythonBuffer =
1091-
unpacked_booleans.cast<py::buffer>().request();
1092-
1093-
MlirType shapedType = mlirAttributeGetType(*this);
1094-
return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
1095-
}
1096-
10971019
template <typename Type>
10981020
py::buffer_info bufferInfo(MlirType shapedType,
10991021
const char *explicitFormat = nullptr) {
1022+
intptr_t rank = mlirShapedTypeGetRank(shapedType);
11001023
// Prepare the data for the buffer_info.
1101-
// Buffer is configured for read-only access inside the `bufferInfo` call.
1024+
// Buffer is configured for read-only access below.
11021025
Type *data = static_cast<Type *>(
11031026
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1104-
return bufferInfo<Type>(shapedType, data, explicitFormat);
1105-
}
1106-
1107-
template <typename Type>
1108-
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
1109-
const char *explicitFormat = nullptr) {
1110-
intptr_t rank = mlirShapedTypeGetRank(shapedType);
11111027
// Prepare the shape for the buffer_info.
11121028
SmallVector<intptr_t, 4> shape;
11131029
for (intptr_t i = 0; i < rank; ++i)

mlir/test/python/ir/array_attributes.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -326,78 +326,6 @@ def testGetDenseElementsF64():
326326
print(np.array(attr))
327327

328328

329-
### 1 bit/boolean integer arrays
330-
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
331-
@run
332-
def testGetDenseElementsI1Signless():
333-
with Context():
334-
array = np.array([True], dtype=np.bool_)
335-
attr = DenseElementsAttr.get(array)
336-
# CHECK: dense<true> : tensor<1xi1>
337-
print(attr)
338-
# CHECK{LITERAL}: [ True]
339-
print(np.array(attr))
340-
341-
array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
342-
attr = DenseElementsAttr.get(array)
343-
# CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
344-
print(attr)
345-
# CHECK{LITERAL}: [[ True False True]
346-
# CHECK{LITERAL}: [ True True False]]
347-
print(np.array(attr))
348-
349-
array = np.array(
350-
[[True, True, False, False], [True, False, True, False]], dtype=np.bool_
351-
)
352-
attr = DenseElementsAttr.get(array)
353-
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
354-
print(attr)
355-
# CHECK{LITERAL}: [[ True True False False]
356-
# CHECK{LITERAL}: [ True False True False]]
357-
print(np.array(attr))
358-
359-
array = np.array(
360-
[
361-
[True, True, False, False],
362-
[True, False, True, False],
363-
[False, False, False, False],
364-
[True, True, True, True],
365-
[True, False, False, True],
366-
],
367-
dtype=np.bool_,
368-
)
369-
attr = DenseElementsAttr.get(array)
370-
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
371-
print(attr)
372-
# CHECK{LITERAL}: [[ True True False False]
373-
# CHECK{LITERAL}: [ True False True False]
374-
# CHECK{LITERAL}: [False False False False]
375-
# CHECK{LITERAL}: [ True True True True]
376-
# CHECK{LITERAL}: [ True False False True]]
377-
print(np.array(attr))
378-
379-
array = np.array(
380-
[
381-
[True, True, False, False, True, True, False, False, False],
382-
[False, False, False, True, False, True, True, False, True],
383-
],
384-
dtype=np.bool_,
385-
)
386-
attr = DenseElementsAttr.get(array)
387-
# CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
388-
print(attr)
389-
# CHECK{LITERAL}: [[ True True False False True True False False False]
390-
# CHECK{LITERAL}: [False False False True False True True False True]]
391-
print(np.array(attr))
392-
393-
array = np.array([], dtype=np.bool_)
394-
attr = DenseElementsAttr.get(array)
395-
# CHECK: dense<> : tensor<0xi1>
396-
print(attr)
397-
# CHECK{LITERAL}: []
398-
print(np.array(attr))
399-
400-
401329
### 16 bit integer arrays
402330
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
403331
@run

0 commit comments

Comments
 (0)