|
14 | 14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
15 | 15 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
16 | 16 | #include "mlir/Dialect/Vector/VectorOps.h"
|
| 17 | +#include "mlir/IR/AffineMap.h" |
17 | 18 | #include "mlir/IR/Attributes.h"
|
18 | 19 | #include "mlir/IR/Builders.h"
|
19 | 20 | #include "mlir/IR/MLIRContext.h"
|
@@ -894,6 +895,129 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
|
894 | 895 | }
|
895 | 896 | };
|
896 | 897 |
|
| 898 | +template <typename ConcreteOp> |
| 899 | +void replaceTransferOp(ConversionPatternRewriter &rewriter, |
| 900 | + LLVMTypeConverter &typeConverter, Location loc, |
| 901 | + Operation *op, ArrayRef<Value> operands, Value dataPtr, |
| 902 | + Value mask); |
| 903 | + |
| 904 | +template <> |
| 905 | +void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter, |
| 906 | + LLVMTypeConverter &typeConverter, |
| 907 | + Location loc, Operation *op, |
| 908 | + ArrayRef<Value> operands, Value dataPtr, |
| 909 | + Value mask) { |
| 910 | + auto xferOp = cast<TransferReadOp>(op); |
| 911 | + auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; |
| 912 | + VectorType fillType = xferOp.getVectorType(); |
| 913 | + Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); |
| 914 | + fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); |
| 915 | + |
| 916 | + auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); |
| 917 | + rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( |
| 918 | + op, vecTy, dataPtr, mask, ValueRange{fill}, |
| 919 | + rewriter.getI32IntegerAttr(1)); |
| 920 | +} |
| 921 | + |
| 922 | +template <> |
| 923 | +void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter, |
| 924 | + LLVMTypeConverter &typeConverter, |
| 925 | + Location loc, Operation *op, |
| 926 | + ArrayRef<Value> operands, Value dataPtr, |
| 927 | + Value mask) { |
| 928 | + auto adaptor = TransferWriteOpOperandAdaptor(operands); |
| 929 | + rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( |
| 930 | + op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1)); |
| 931 | +} |
| 932 | + |
| 933 | +static TransferReadOpOperandAdaptor |
| 934 | +getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) { |
| 935 | + return TransferReadOpOperandAdaptor(operands); |
| 936 | +} |
| 937 | + |
| 938 | +static TransferWriteOpOperandAdaptor |
| 939 | +getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) { |
| 940 | + return TransferWriteOpOperandAdaptor(operands); |
| 941 | +} |
| 942 | + |
| 943 | +/// Conversion pattern that converts a 1-D vector transfer read/write op in a |
| 944 | +/// sequence of: |
| 945 | +/// 1. Bitcast to vector form. |
| 946 | +/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. |
| 947 | +/// 3. Create a mask where offsetVector is compared against memref upper bound. |
| 948 | +/// 4. Rewrite op as a masked read or write. |
| 949 | +template <typename ConcreteOp> |
| 950 | +class VectorTransferConversion : public ConvertToLLVMPattern { |
| 951 | +public: |
| 952 | + explicit VectorTransferConversion(MLIRContext *context, |
| 953 | + LLVMTypeConverter &typeConv) |
| 954 | + : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, |
| 955 | + typeConv) {} |
| 956 | + |
| 957 | + LogicalResult |
| 958 | + matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 959 | + ConversionPatternRewriter &rewriter) const override { |
| 960 | + auto xferOp = cast<ConcreteOp>(op); |
| 961 | + auto adaptor = getTransferOpAdapter(xferOp, operands); |
| 962 | + if (xferOp.getMemRefType().getRank() != 1) |
| 963 | + return failure(); |
| 964 | + if (!xferOp.permutation_map().isIdentity()) |
| 965 | + return failure(); |
| 966 | + |
| 967 | + auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; |
| 968 | + |
| 969 | + Location loc = op->getLoc(); |
| 970 | + Type i64Type = rewriter.getIntegerType(64); |
| 971 | + MemRefType memRefType = xferOp.getMemRefType(); |
| 972 | + |
| 973 | + // 1. Get the source/dst address as an LLVM vector pointer. |
| 974 | + // TODO: support alignment when possible. |
| 975 | + Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), |
| 976 | + adaptor.indices(), rewriter, getModule()); |
| 977 | + auto vecTy = |
| 978 | + toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); |
| 979 | + auto vectorDataPtr = |
| 980 | + rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); |
| 981 | + |
| 982 | + // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. |
| 983 | + unsigned vecWidth = vecTy.getVectorNumElements(); |
| 984 | + VectorType vectorCmpType = VectorType::get(vecWidth, i64Type); |
| 985 | + SmallVector<int64_t, 8> indices; |
| 986 | + indices.reserve(vecWidth); |
| 987 | + for (unsigned i = 0; i < vecWidth; ++i) |
| 988 | + indices.push_back(i); |
| 989 | + Value linearIndices = rewriter.create<ConstantOp>( |
| 990 | + loc, vectorCmpType, |
| 991 | + DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices))); |
| 992 | + linearIndices = rewriter.create<LLVM::DialectCastOp>( |
| 993 | + loc, toLLVMTy(vectorCmpType), linearIndices); |
| 994 | + |
| 995 | + // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. |
| 996 | + Value offsetIndex = *(xferOp.indices().begin()); |
| 997 | + offsetIndex = rewriter.create<IndexCastOp>( |
| 998 | + loc, vectorCmpType.getElementType(), offsetIndex); |
| 999 | + Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex); |
| 1000 | + Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices); |
| 1001 | + |
| 1002 | + // 4. Let dim the memref dimension, compute the vector comparison mask: |
| 1003 | + // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] |
| 1004 | + Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0); |
| 1005 | + dim = |
| 1006 | + rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim); |
| 1007 | + dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim); |
| 1008 | + Value mask = |
| 1009 | + rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim); |
| 1010 | + mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()), |
| 1011 | + mask); |
| 1012 | + |
| 1013 | + // 5. Rewrite as a masked read / write. |
| 1014 | + replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands, |
| 1015 | + vectorDataPtr, mask); |
| 1016 | + |
| 1017 | + return success(); |
| 1018 | + } |
| 1019 | +}; |
| 1020 | + |
897 | 1021 | class VectorPrintOpConversion : public ConvertToLLVMPattern {
|
898 | 1022 | public:
|
899 | 1023 | explicit VectorPrintOpConversion(MLIRContext *context,
|
@@ -1079,16 +1203,25 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
|
1079 | 1203 | void mlir::populateVectorToLLVMConversionPatterns(
|
1080 | 1204 | LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
1081 | 1205 | MLIRContext *ctx = converter.getDialect()->getContext();
|
| 1206 | + // clang-format off |
1082 | 1207 | patterns.insert<VectorFMAOpNDRewritePattern,
|
1083 | 1208 | VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
1084 | 1209 | VectorInsertStridedSliceOpSameRankRewritePattern,
|
1085 | 1210 | VectorStridedSliceOpConversion>(ctx);
|
1086 |
| - patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion, |
1087 |
| - VectorShuffleOpConversion, VectorExtractElementOpConversion, |
1088 |
| - VectorExtractOpConversion, VectorFMAOp1DConversion, |
1089 |
| - VectorInsertElementOpConversion, VectorInsertOpConversion, |
1090 |
| - VectorTypeCastOpConversion, VectorPrintOpConversion>( |
1091 |
| - ctx, converter); |
| 1211 | + patterns |
| 1212 | + .insert<VectorBroadcastOpConversion, |
| 1213 | + VectorReductionOpConversion, |
| 1214 | + VectorShuffleOpConversion, |
| 1215 | + VectorExtractElementOpConversion, |
| 1216 | + VectorExtractOpConversion, |
| 1217 | + VectorFMAOp1DConversion, |
| 1218 | + VectorInsertElementOpConversion, |
| 1219 | + VectorInsertOpConversion, |
| 1220 | + VectorPrintOpConversion, |
| 1221 | + VectorTransferConversion<TransferReadOp>, |
| 1222 | + VectorTransferConversion<TransferWriteOp>, |
| 1223 | + VectorTypeCastOpConversion>(ctx, converter); |
| 1224 | + // clang-format on |
1092 | 1225 | }
|
1093 | 1226 |
|
1094 | 1227 | void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
|
0 commit comments