|
13 | 13 | #include "IRModule.h"
|
14 | 14 |
|
15 | 15 | #include "PybindUtils.h"
|
16 |
| -#include <pybind11/numpy.h> |
17 | 16 |
|
18 | 17 | #include "llvm/ADT/ScopeExit.h"
|
19 | 18 | #include "llvm/Support/raw_ostream.h"
|
@@ -758,10 +757,103 @@ class PyDenseElementsAttribute
|
758 | 757 | throw py::error_already_set();
|
759 | 758 | }
|
760 | 759 | auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
|
| 760 | + SmallVector<int64_t> shape; |
| 761 | + if (explicitShape) { |
| 762 | + shape.append(explicitShape->begin(), explicitShape->end()); |
| 763 | + } else { |
| 764 | + shape.append(view.shape, view.shape + view.ndim); |
| 765 | + } |
761 | 766 |
|
| 767 | + MlirAttribute encodingAttr = mlirAttributeGetNull(); |
762 | 768 | MlirContext context = contextWrapper->get();
|
763 |
| - MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, |
764 |
| - explicitShape, context); |
| 769 | + |
| 770 | + // Detect format codes that are suitable for bulk loading. This includes |
| 771 | + // all byte-aligned integer and floating-point types up to 8 bytes. |
| 772 | + // Notably, this excludes bool (which needs to be bit-packed) and |
| 773 | + // other exotic types which do not have a direct representation in the |
| 774 | + // buffer protocol (e.g. complex). |
| 775 | + std::optional<MlirType> bulkLoadElementType; |
| 776 | + if (explicitType) { |
| 777 | + bulkLoadElementType = *explicitType; |
| 778 | + } else { |
| 779 | + std::string_view format(view.format); |
| 780 | + if (format == "f") { |
| 781 | + // f32 |
| 782 | + assert(view.itemsize == 4 && "mismatched array itemsize"); |
| 783 | + bulkLoadElementType = mlirF32TypeGet(context); |
| 784 | + } else if (format == "d") { |
| 785 | + // f64 |
| 786 | + assert(view.itemsize == 8 && "mismatched array itemsize"); |
| 787 | + bulkLoadElementType = mlirF64TypeGet(context); |
| 788 | + } else if (format == "e") { |
| 789 | + // f16 |
| 790 | + assert(view.itemsize == 2 && "mismatched array itemsize"); |
| 791 | + bulkLoadElementType = mlirF16TypeGet(context); |
| 792 | + } else if (isSignedIntegerFormat(format)) { |
| 793 | + if (view.itemsize == 4) { |
| 794 | + // i32 |
| 795 | + bulkLoadElementType = signless |
| 796 | + ? mlirIntegerTypeGet(context, 32) |
| 797 | + : mlirIntegerTypeSignedGet(context, 32); |
| 798 | + } else if (view.itemsize == 8) { |
| 799 | + // i64 |
| 800 | + bulkLoadElementType = signless |
| 801 | + ? mlirIntegerTypeGet(context, 64) |
| 802 | + : mlirIntegerTypeSignedGet(context, 64); |
| 803 | + } else if (view.itemsize == 1) { |
| 804 | + // i8 |
| 805 | + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
| 806 | + : mlirIntegerTypeSignedGet(context, 8); |
| 807 | + } else if (view.itemsize == 2) { |
| 808 | + // i16 |
| 809 | + bulkLoadElementType = signless |
| 810 | + ? mlirIntegerTypeGet(context, 16) |
| 811 | + : mlirIntegerTypeSignedGet(context, 16); |
| 812 | + } |
| 813 | + } else if (isUnsignedIntegerFormat(format)) { |
| 814 | + if (view.itemsize == 4) { |
| 815 | + // unsigned i32 |
| 816 | + bulkLoadElementType = signless |
| 817 | + ? mlirIntegerTypeGet(context, 32) |
| 818 | + : mlirIntegerTypeUnsignedGet(context, 32); |
| 819 | + } else if (view.itemsize == 8) { |
| 820 | + // unsigned i64 |
| 821 | + bulkLoadElementType = signless |
| 822 | + ? mlirIntegerTypeGet(context, 64) |
| 823 | + : mlirIntegerTypeUnsignedGet(context, 64); |
| 824 | + } else if (view.itemsize == 1) { |
| 825 | + // unsigned i8 |
| 826 | + bulkLoadElementType = signless |
| 827 | + ? mlirIntegerTypeGet(context, 8) |
| 828 | + : mlirIntegerTypeUnsignedGet(context, 8); |
| 829 | + } else if (view.itemsize == 2) { |
| 830 | + // unsigned i16 |
| 831 | + bulkLoadElementType = signless |
| 832 | + ? mlirIntegerTypeGet(context, 16) |
| 833 | + : mlirIntegerTypeUnsignedGet(context, 16); |
| 834 | + } |
| 835 | + } |
| 836 | + if (!bulkLoadElementType) { |
| 837 | + throw std::invalid_argument( |
| 838 | + std::string("unimplemented array format conversion from format: ") + |
| 839 | + std::string(format)); |
| 840 | + } |
| 841 | + } |
| 842 | + |
| 843 | + MlirType shapedType; |
| 844 | + if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
| 845 | + if (explicitShape) { |
| 846 | + throw std::invalid_argument("Shape can only be specified explicitly " |
| 847 | + "when the type is not a shaped type."); |
| 848 | + } |
| 849 | + shapedType = *bulkLoadElementType; |
| 850 | + } else { |
| 851 | + shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), |
| 852 | + *bulkLoadElementType, encodingAttr); |
| 853 | + } |
| 854 | + size_t rawBufferSize = view.len; |
| 855 | + MlirAttribute attr = |
| 856 | + mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); |
765 | 857 | if (mlirAttributeIsNull(attr)) {
|
766 | 858 | throw std::invalid_argument(
|
767 | 859 | "DenseElementsAttr could not be constructed from the given buffer. "
|
@@ -871,13 +963,6 @@ class PyDenseElementsAttribute
|
871 | 963 | // unsigned i16
|
872 | 964 | return bufferInfo<uint16_t>(shapedType);
|
873 | 965 | }
|
874 |
| - } else if (mlirTypeIsAInteger(elementType) && |
875 |
| - mlirIntegerTypeGetWidth(elementType) == 1) { |
876 |
| - // i1 / bool |
877 |
| - // We can not send the buffer directly back to Python, because the i1 |
878 |
| - // values are bitpacked within MLIR. We call numpy's unpackbits function |
879 |
| - // to convert the bytes. |
880 |
| - return getBooleanBufferFromBitpackedAttribute(); |
881 | 966 | }
|
882 | 967 |
|
883 | 968 | // TODO: Currently crashes the program.
|
@@ -931,183 +1016,14 @@ class PyDenseElementsAttribute
|
931 | 1016 | code == 'q';
|
932 | 1017 | }
|
933 | 1018 |
|
934 |
| - static MlirType |
935 |
| - getShapedType(std::optional<MlirType> bulkLoadElementType, |
936 |
| - std::optional<std::vector<int64_t>> explicitShape, |
937 |
| - Py_buffer &view) { |
938 |
| - SmallVector<int64_t> shape; |
939 |
| - if (explicitShape) { |
940 |
| - shape.append(explicitShape->begin(), explicitShape->end()); |
941 |
| - } else { |
942 |
| - shape.append(view.shape, view.shape + view.ndim); |
943 |
| - } |
944 |
| - |
945 |
| - if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
946 |
| - if (explicitShape) { |
947 |
| - throw std::invalid_argument("Shape can only be specified explicitly " |
948 |
| - "when the type is not a shaped type."); |
949 |
| - } |
950 |
| - return *bulkLoadElementType; |
951 |
| - } else { |
952 |
| - MlirAttribute encodingAttr = mlirAttributeGetNull(); |
953 |
| - return mlirRankedTensorTypeGet(shape.size(), shape.data(), |
954 |
| - *bulkLoadElementType, encodingAttr); |
955 |
| - } |
956 |
| - } |
957 |
| - |
958 |
| - static MlirAttribute getAttributeFromBuffer( |
959 |
| - Py_buffer &view, bool signless, std::optional<PyType> explicitType, |
960 |
| - std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { |
961 |
| - // Detect format codes that are suitable for bulk loading. This includes |
962 |
| - // all byte aligned integer and floating point types up to 8 bytes. |
963 |
| - // Notably, this excludes exotic types which do not have a direct |
964 |
| - // representation in the buffer protocol (e.g. complex). |
965 |
| - std::optional<MlirType> bulkLoadElementType; |
966 |
| - if (explicitType) { |
967 |
| - bulkLoadElementType = *explicitType; |
968 |
| - } else { |
969 |
| - std::string_view format(view.format); |
970 |
| - if (format == "f") { |
971 |
| - // f32 |
972 |
| - assert(view.itemsize == 4 && "mismatched array itemsize"); |
973 |
| - bulkLoadElementType = mlirF32TypeGet(context); |
974 |
| - } else if (format == "d") { |
975 |
| - // f64 |
976 |
| - assert(view.itemsize == 8 && "mismatched array itemsize"); |
977 |
| - bulkLoadElementType = mlirF64TypeGet(context); |
978 |
| - } else if (format == "e") { |
979 |
| - // f16 |
980 |
| - assert(view.itemsize == 2 && "mismatched array itemsize"); |
981 |
| - bulkLoadElementType = mlirF16TypeGet(context); |
982 |
| - } else if (format == "?") { |
983 |
| - // i1 |
984 |
| - // The i1 type needs to be bit-packed, so we will handle it separately |
985 |
| - return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, |
986 |
| - context); |
987 |
| - } else if (isSignedIntegerFormat(format)) { |
988 |
| - if (view.itemsize == 4) { |
989 |
| - // i32 |
990 |
| - bulkLoadElementType = signless |
991 |
| - ? mlirIntegerTypeGet(context, 32) |
992 |
| - : mlirIntegerTypeSignedGet(context, 32); |
993 |
| - } else if (view.itemsize == 8) { |
994 |
| - // i64 |
995 |
| - bulkLoadElementType = signless |
996 |
| - ? mlirIntegerTypeGet(context, 64) |
997 |
| - : mlirIntegerTypeSignedGet(context, 64); |
998 |
| - } else if (view.itemsize == 1) { |
999 |
| - // i8 |
1000 |
| - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
1001 |
| - : mlirIntegerTypeSignedGet(context, 8); |
1002 |
| - } else if (view.itemsize == 2) { |
1003 |
| - // i16 |
1004 |
| - bulkLoadElementType = signless |
1005 |
| - ? mlirIntegerTypeGet(context, 16) |
1006 |
| - : mlirIntegerTypeSignedGet(context, 16); |
1007 |
| - } |
1008 |
| - } else if (isUnsignedIntegerFormat(format)) { |
1009 |
| - if (view.itemsize == 4) { |
1010 |
| - // unsigned i32 |
1011 |
| - bulkLoadElementType = signless |
1012 |
| - ? mlirIntegerTypeGet(context, 32) |
1013 |
| - : mlirIntegerTypeUnsignedGet(context, 32); |
1014 |
| - } else if (view.itemsize == 8) { |
1015 |
| - // unsigned i64 |
1016 |
| - bulkLoadElementType = signless |
1017 |
| - ? mlirIntegerTypeGet(context, 64) |
1018 |
| - : mlirIntegerTypeUnsignedGet(context, 64); |
1019 |
| - } else if (view.itemsize == 1) { |
1020 |
| - // unsigned i8 |
1021 |
| - bulkLoadElementType = signless |
1022 |
| - ? mlirIntegerTypeGet(context, 8) |
1023 |
| - : mlirIntegerTypeUnsignedGet(context, 8); |
1024 |
| - } else if (view.itemsize == 2) { |
1025 |
| - // unsigned i16 |
1026 |
| - bulkLoadElementType = signless |
1027 |
| - ? mlirIntegerTypeGet(context, 16) |
1028 |
| - : mlirIntegerTypeUnsignedGet(context, 16); |
1029 |
| - } |
1030 |
| - } |
1031 |
| - if (!bulkLoadElementType) { |
1032 |
| - throw std::invalid_argument( |
1033 |
| - std::string("unimplemented array format conversion from format: ") + |
1034 |
| - std::string(format)); |
1035 |
| - } |
1036 |
| - } |
1037 |
| - |
1038 |
| - MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); |
1039 |
| - return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); |
1040 |
| - } |
1041 |
| - |
1042 |
| - // There is a complication for boolean numpy arrays, as numpy represents them |
1043 |
| - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans |
1044 |
| - // per byte. |
1045 |
| - static MlirAttribute getBitpackedAttributeFromBooleanBuffer( |
1046 |
| - Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, |
1047 |
| - MlirContext &context) { |
1048 |
| - if (llvm::endianness::native != llvm::endianness::little) { |
1049 |
| - // Given we have no good way of testing the behavior on big-endian systems |
1050 |
| - // we will throw |
1051 |
| - throw py::type_error("Constructing a bit-packed MLIR attribute is " |
1052 |
| - "unsupported on big-endian systems"); |
1053 |
| - } |
1054 |
| - |
1055 |
| - py::array_t<uint8_t> unpackedArray(view.len, |
1056 |
| - static_cast<uint8_t *>(view.buf)); |
1057 |
| - |
1058 |
| - py::module numpy = py::module::import("numpy"); |
1059 |
| - py::object packbits_func = numpy.attr("packbits"); |
1060 |
| - py::object packed_booleans = |
1061 |
| - packbits_func(unpackedArray, "bitorder"_a = "little"); |
1062 |
| - py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request(); |
1063 |
| - |
1064 |
| - MlirType bitpackedType = |
1065 |
| - getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); |
1066 |
| - return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, |
1067 |
| - pythonBuffer.ptr); |
1068 |
| - } |
1069 |
| - |
1070 |
| - // This does the opposite transformation of |
1071 |
| - // `getBitpackedAttributeFromBooleanBuffer` |
1072 |
| - py::buffer_info getBooleanBufferFromBitpackedAttribute() { |
1073 |
| - if (llvm::endianness::native != llvm::endianness::little) { |
1074 |
| - // Given we have no good way of testing the behavior on big-endian systems |
1075 |
| - // we will throw |
1076 |
| - throw py::type_error("Constructing a numpy array from a MLIR attribute " |
1077 |
| - "is unsupported on big-endian systems"); |
1078 |
| - } |
1079 |
| - |
1080 |
| - int64_t numBooleans = mlirElementsAttrGetNumElements(*this); |
1081 |
| - int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); |
1082 |
| - uint8_t *bitpackedData = static_cast<uint8_t *>( |
1083 |
| - const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
1084 |
| - py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData); |
1085 |
| - |
1086 |
| - py::module numpy = py::module::import("numpy"); |
1087 |
| - py::object unpackbits_func = numpy.attr("unpackbits"); |
1088 |
| - py::object unpacked_booleans = |
1089 |
| - unpackbits_func(packedArray, "bitorder"_a = "little"); |
1090 |
| - py::buffer_info pythonBuffer = |
1091 |
| - unpacked_booleans.cast<py::buffer>().request(); |
1092 |
| - |
1093 |
| - MlirType shapedType = mlirAttributeGetType(*this); |
1094 |
| - return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?"); |
1095 |
| - } |
1096 |
| - |
1097 | 1019 | template <typename Type>
|
1098 | 1020 | py::buffer_info bufferInfo(MlirType shapedType,
|
1099 | 1021 | const char *explicitFormat = nullptr) {
|
| 1022 | + intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1100 | 1023 | // Prepare the data for the buffer_info.
|
1101 |
| - // Buffer is configured for read-only access inside the `bufferInfo` call. |
| 1024 | + // Buffer is configured for read-only access below. |
1102 | 1025 | Type *data = static_cast<Type *>(
|
1103 | 1026 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
|
1104 |
| - return bufferInfo<Type>(shapedType, data, explicitFormat); |
1105 |
| - } |
1106 |
| - |
1107 |
| - template <typename Type> |
1108 |
| - py::buffer_info bufferInfo(MlirType shapedType, Type *data, |
1109 |
| - const char *explicitFormat = nullptr) { |
1110 |
| - intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1111 | 1027 | // Prepare the shape for the buffer_info.
|
1112 | 1028 | SmallVector<intptr_t, 4> shape;
|
1113 | 1029 | for (intptr_t i = 0; i < rank; ++i)
|
|
0 commit comments