|
37 | 37 | #include "mlir/IR/Types.h"
|
38 | 38 | #include "mlir/Interfaces/VectorInterfaces.h"
|
39 | 39 |
|
| 40 | +#include "llvm/ADT/STLExtras.h" |
40 | 41 | #include "llvm/Support/CommandLine.h"
|
41 | 42 | #include "llvm/Support/Debug.h"
|
42 | 43 | #include "llvm/Support/raw_ostream.h"
|
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
|
2729 | 2730 | }
|
2730 | 2731 | };
|
2731 | 2732 |
|
| 2733 | +/// Progressive lowering of transfer_read. This pattern supports lowering of |
| 2734 | +/// `vector.transfer_read` to a combination of `vector.load` and |
| 2735 | +/// `vector.broadcast` if all of the following hold: |
| 2736 | +/// - The op reads from a memref with the default layout. |
| 2737 | +/// - Masking is not required. |
| 2738 | +/// - If the memref's element type is a vector type then it coincides with the |
| 2739 | +/// result type. |
| 2740 | +/// - The permutation map doesn't perform permutation (broadcasting is allowed). |
| 2741 | +struct TransferReadToVectorLoadLowering |
| 2742 | + : public OpRewritePattern<vector::TransferReadOp> { |
| 2743 | + TransferReadToVectorLoadLowering(MLIRContext *context) |
| 2744 | + : OpRewritePattern<vector::TransferReadOp>(context) {} |
| 2745 | + LogicalResult matchAndRewrite(vector::TransferReadOp read, |
| 2746 | + PatternRewriter &rewriter) const override { |
| 2747 | + SmallVector<unsigned, 4> broadcastedDims; |
| 2748 | + // TODO: Support permutations. |
| 2749 | + if (!read.permutation_map().isMinorIdentityWithBroadcasting( |
| 2750 | + &broadcastedDims)) |
| 2751 | + return failure(); |
| 2752 | + auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); |
| 2753 | + if (!memRefType) |
| 2754 | + return failure(); |
| 2755 | + |
| 2756 | + // If there is broadcasting involved then we first load the unbroadcasted |
| 2757 | + // vector, and then broadcast it with `vector.broadcast`. |
| 2758 | + ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); |
| 2759 | + SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), |
| 2760 | + vectorShape.end()); |
| 2761 | + for (unsigned i : broadcastedDims) |
| 2762 | + unbroadcastedVectorShape[i] = 1; |
| 2763 | + VectorType unbroadcastedVectorType = VectorType::get( |
| 2764 | + unbroadcastedVectorShape, read.getVectorType().getElementType()); |
| 2765 | + |
| 2766 | + // `vector.load` supports vector types as memref's elements only when the |
| 2767 | + // resulting vector type is the same as the element type. |
| 2768 | + if (memRefType.getElementType().isa<VectorType>() && |
| 2769 | + memRefType.getElementType() != unbroadcastedVectorType) |
| 2770 | + return failure(); |
| 2771 | + // Only the default layout is supported by `vector.load`. |
| 2772 | + // TODO: Support non-default layouts. |
| 2773 | + if (!memRefType.getAffineMaps().empty()) |
| 2774 | + return failure(); |
| 2775 | + // TODO: When masking is required, we can create a MaskedLoadOp |
| 2776 | + if (read.hasMaskedDim()) |
| 2777 | + return failure(); |
| 2778 | + |
| 2779 | + Operation *loadOp; |
| 2780 | + if (!broadcastedDims.empty() && |
| 2781 | + unbroadcastedVectorType.getNumElements() == 1) { |
| 2782 | + // If broadcasting is required and the number of loaded elements is 1 then |
| 2783 | + // we can create `std.load` instead of `vector.load`. |
| 2784 | + loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(), |
| 2785 | + read.indices()); |
| 2786 | + } else { |
| 2787 | + // Otherwise create `vector.load`. |
| 2788 | + loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), |
| 2789 | + unbroadcastedVectorType, |
| 2790 | + read.source(), read.indices()); |
| 2791 | + } |
| 2792 | + |
| 2793 | + // Insert a broadcasting op if required. |
| 2794 | + if (!broadcastedDims.empty()) { |
| 2795 | + rewriter.replaceOpWithNewOp<vector::BroadcastOp>( |
| 2796 | + read, read.getVectorType(), loadOp->getResult(0)); |
| 2797 | + } else { |
| 2798 | + rewriter.replaceOp(read, loadOp->getResult(0)); |
| 2799 | + } |
| 2800 | + |
| 2801 | + return success(); |
| 2802 | + } |
| 2803 | +}; |
| 2804 | + |
| 2805 | +/// Progressive lowering of transfer_write. This pattern supports lowering of |
| 2806 | +/// `vector.transfer_write` to `vector.store` if all of the following hold: |
| 2807 | +/// - The op writes to a memref with the default layout. |
| 2808 | +/// - Masking is not required. |
| 2809 | +/// - If the memref's element type is a vector type then it coincides with the |
| 2810 | +/// type of the written value. |
| 2811 | +/// - The permutation map is the minor identity map (neither permutation nor |
| 2812 | +/// broadcasting is allowed). |
| 2813 | +struct TransferWriteToVectorStoreLowering |
| 2814 | + : public OpRewritePattern<vector::TransferWriteOp> { |
| 2815 | + TransferWriteToVectorStoreLowering(MLIRContext *context) |
| 2816 | + : OpRewritePattern<vector::TransferWriteOp>(context) {} |
| 2817 | + LogicalResult matchAndRewrite(vector::TransferWriteOp write, |
| 2818 | + PatternRewriter &rewriter) const override { |
| 2819 | + // TODO: Support non-minor-identity maps |
| 2820 | + if (!write.permutation_map().isMinorIdentity()) |
| 2821 | + return failure(); |
| 2822 | + auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); |
| 2823 | + if (!memRefType) |
| 2824 | + return failure(); |
| 2825 | + // `vector.store` supports vector types as memref's elements only when the |
| 2826 | + // type of the vector value being written is the same as the element type. |
| 2827 | + if (memRefType.getElementType().isa<VectorType>() && |
| 2828 | + memRefType.getElementType() != write.getVectorType()) |
| 2829 | + return failure(); |
| 2830 | + // Only the default layout is supported by `vector.store`. |
| 2831 | + // TODO: Support non-default layouts. |
| 2832 | + if (!memRefType.getAffineMaps().empty()) |
| 2833 | + return failure(); |
| 2834 | + // TODO: When masking is required, we can create a MaskedStoreOp |
| 2835 | + if (write.hasMaskedDim()) |
| 2836 | + return failure(); |
| 2837 | + rewriter.replaceOpWithNewOp<vector::StoreOp>( |
| 2838 | + write, write.vector(), write.source(), write.indices()); |
| 2839 | + return success(); |
| 2840 | + } |
| 2841 | +}; |
| 2842 | + |
2732 | 2843 | // Trims leading one dimensions from `oldType` and returns the result type.
|
2733 | 2844 | // Returns `vector<1xT>` if `oldType` only has one element.
|
2734 | 2845 | static VectorType trimLeadingOneDims(VectorType oldType) {
|
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
3201 | 3312 | ContractionOpToOuterProductOpLowering>(parameters, context);
|
3202 | 3313 | // clang-format on
|
3203 | 3314 | }
|
| 3315 | + |
| 3316 | +void mlir::vector::populateVectorTransferLoweringPatterns( |
| 3317 | + OwningRewritePatternList &patterns, MLIRContext *context) { |
| 3318 | + patterns.insert<TransferReadToVectorLoadLowering, |
| 3319 | + TransferWriteToVectorStoreLowering>(context); |
| 3320 | +} |
0 commit comments