[mlir][vector] Drop innermost unit dims on transfer_write. #78554
Conversation
The revision renames DropInnerMostUnitDims to DropInnerMostUnitDimsTransferRead and adds support for vector.transfer_write. It refactors common methods (i.e., getTransferFoldableInnerUnitDims and getMemRefTypeWithDroppingInnerDims) and uses them in both patterns.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

The revision renames DropInnerMostUnitDims to DropInnerMostUnitDimsTransferRead and adds support for vector.transfer_write. It refactors common methods (i.e., getTransferFoldableInnerUnitDims and getMemRefTypeWithDroppingInnerDims) and uses them in both patterns.

A vector.transfer_write with inner unit dims will be lowered to vector.shape_cast + memref.subview + vector.transfer_write. E.g.,

vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
  {in_bounds = [true, true, true, true]}
  : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>

will be lowered to

%subview = memref.subview %arg0
  [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
  : memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
  to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
%0 = vector.shape_cast %arg1 : vector<1x16x16x1xf32> to vector<1x16x16xf32>
vector.transfer_write %0, %subview[%c0, %arg2, %c0]
  {in_bounds = [true, true, true]}
  : vector<1x16x16xf32>, memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>

Full diff: https://github.com/llvm/llvm-project/pull/78554.diff

2 Files Affected:
- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
- mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
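For reference, the pre-existing transfer_read pattern (now renamed to DropInnerMostUnitDimsTransferRead) performs the analogous rewrite. A minimal sketch of its effect, with hypothetical value names chosen for illustration:

// Before: the innermost dim is unit-sized and contiguous in both the memref and the vector.
%v = vector.transfer_read %src[%c0, %i, %c0, %c0], %pad
  {in_bounds = [true, true, true, true]}
  : memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, vector<1x16x16x1xf32>

// After: the unit dim is peeled off with a rank-reducing subview, the read is
// performed at the reduced rank, and a shape_cast restores the original vector type.
%subview = memref.subview %src [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
  : memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
  to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
%read = vector.transfer_read %subview[%c0, %i, %c0], %pad
  {in_bounds = [true, true, true]}
  : memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>, vector<1x16x16xf32>
%v = vector.shape_cast %read : vector<1x16x16xf32> to vector<1x16x16x1xf32>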
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07981466d..7c276ca8101221 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,71 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
}
};
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims can be folded away from transfer ops. It returns
+/// a failure if strides and offsets can not be resolved.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+ SmallVector<int64_t> srcStrides;
+ int64_t srcOffset;
+ if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ return failure();
+
+ // According to vector.transfer_read/write semantics, the vector can be a
+ // slice. It pads the indices with `1` starting from beginning. Thus, we have
+ // to offset the check index with `rankDiff` in `srcStrides` and source dim
+ // sizes.
+ size_t result = 0;
+ int rankDiff = srcType.getRank() - vectorType.getRank();
+ for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+ // Check that the inner dim size is 1 for both memref/tensor type and
+ // vector slice. It can be folded only if they are 1 and the stride is 1.
+ int dim = vectorType.getRank() - i - 1;
+ if (srcStrides[dim + rankDiff] == 1 &&
+ srcType.getDimSize(dim + rankDiff) == 1 &&
+ vectorType.getDimSize(dim) == 1) {
+ result++;
+ } else {
+ break;
+ }
+ }
+ return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+ MemRefType srcType,
+ size_t dimsToDrop) {
+ MemRefType resultMemrefType;
+ MemRefLayoutAttrInterface layout = srcType.getLayout();
+ if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), nullptr,
+ srcType.getMemorySpace());
+ }
+ MemRefLayoutAttrInterface updatedLayout;
+ if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+ auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+ updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+ strided.getOffset(), strides);
+ } else {
+ AffineMap map = srcType.getLayout().getAffineMap();
+ int numSymbols = map.getNumSymbols();
+ for (size_t i = 0; i < dimsToDrop; ++i) {
+ int dim = srcType.getRank() - i - 1;
+ map = map.replace(builder.getAffineDimExpr(dim),
+ builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+ numSymbols);
+ }
+ }
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), updatedLayout,
+ srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+ : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1240,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (targetType.getRank() <= 1)
return failure();
- SmallVector<int64_t> srcStrides;
- int64_t srcOffset;
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
- return failure();
-
- // According to vector.transfer_read semantics, the result can be a slice.
- // It pads the indices with `1` starting from beginning. Thus, we have to
- // offset the check index with `rankDiff` in `srcStrides` and source dim
- // sizes.
- size_t dimsToDrop = 0;
- int rankDiff = srcType.getRank() - targetType.getRank();
- for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
- // Check that the inner dim size is 1 for both memref/tensor type and
- // vector slice. It can be folded only if they are 1 and the stride is 1.
- int dim = targetType.getRank() - i - 1;
- if (srcStrides[dim + rankDiff] == 1 &&
- srcType.getDimSize(dim + rankDiff) == 1 &&
- targetType.getDimSize(dim) == 1) {
- dimsToDrop++;
- } else {
- break;
- }
- }
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();
@@ -1207,35 +1253,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());
- MemRefType resultMemrefType;
- MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- nullptr, srcType.getMemorySpace());
- } else {
- MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
- auto strides =
- llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
- updatedLayout = StridedLayoutAttr::get(strided.getContext(),
- strided.getOffset(), strides);
- } else {
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(rewriter.getAffineDimExpr(dim),
- rewriter.getAffineConstantExpr(0),
- map.getNumDims() - 1, numSymbols);
- }
- }
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- updatedLayout, srcType.getMemorySpace());
- }
-
auto loc = readOp.getLoc();
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
@@ -1261,6 +1281,73 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+class DropInnerMostUnitDimsTransferWrite
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: support mask.
+ if (writeOp.getMask())
+ return failure();
+
+ auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+ if (!srcType || !srcType.hasStaticShape())
+ return failure();
+
+ if (!writeOp.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ auto targetType = writeOp.getVectorType();
+ if (targetType.getRank() <= 1)
+ return failure();
+
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
+ if (dimsToDrop == 0)
+ return failure();
+
+ auto resultTargetVecType =
+ VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+ targetType.getElementType());
+
+ auto loc = writeOp.getLoc();
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+ SmallVector<int64_t> offsets(srcType.getRank(), 0);
+ SmallVector<int64_t> strides(srcType.getRank(), 1);
+
+ ArrayAttr inBoundsAttr =
+ writeOp.getInBounds()
+ ? rewriter.getArrayAttr(
+ writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+ : ArrayAttr();
+ Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+ loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+ strides);
+ auto permMap = getTransferMinorIdentityMap(
+ cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+ auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, resultTargetVecType, writeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ writeOp, shapeCast, rankedReducedView,
+ writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+ // TODO: support mask.
+ /*mask=*/Value(), inBoundsAttr);
+ return success();
+ }
+};
+
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -1696,7 +1783,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+ patterns.add<DropInnerMostUnitDimsTransferRead,
+ DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateSinkVectorBroadcastPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b9fe2e7f..59116c19b46ec2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,24 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
// CHECK-NOT: memref.subview
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
// CHECK: return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+ {in_bounds = [true, true, true, true]}
+ : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+ return
+}
+// CHECK: func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
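Note (illustrative, not part of this patch): the fold only applies when the innermost dimensions are unit-sized in both the memref and the vector and have stride 1. A hypothetical write like the following is left untouched, since getTransferFoldableInnerUnitDims finds no foldable dims:

// The innermost memref dim is 4, not 1, so no inner unit dims can be dropped.
vector.transfer_write %vec, %dest[%c0, %i, %c0, %c0]
  {in_bounds = [true, true, true, true]}
  : vector<1x16x16x1xf32>, memref<1x512x16x4xf32>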
Thanks for improving this! Makes sense modulo some minor suggestions.
I appreciate that some of my comments refer to code that you are merely moving from one place to another - feel free to ignore these. Adding more tests/comments would be nice :)
LGTM, thanks for addressing my comments! I've left a few nits, but feel free to ignore.