-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Rewrite vector transfer write with unit dims for scalable vectors #85270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Rewrite vector transfer write with unit dims for scalable vectors #85270
Conversation
Signed-off-by: Crefeda Rodrigues <[email protected]>
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR in a comment. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir-vector Author: Crefeda Rodrigues (cfRod)

Changes: This PR fixes the issue of lowering vector transfer writes on scalable vectors with unit dims to vector broadcast ops and vector transpose ops - where the scalable dims are dropped. Full diff: https://github.com/llvm/llvm-project/pull/85270.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..cef8a497a80996 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -226,6 +226,38 @@ struct TransferWritePermutationLowering
/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
/// vector<1x8x16xf32>
/// ```
+/// Returns the number of dims that aren't unit dims.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+ return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+static int getFirstNonUnitDim(MemRefType oldType) {
+ int idx = 0;
+ for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+ if (dimSize == 1) {
+ continue;
+ } else {
+ idx = dimIdx;
+ break;
+ }
+ }
+ return idx;
+}
+
+static int getLasttNonUnitDim(MemRefType oldType) {
+ int idx = 0;
+ for (auto [dimIdx, dimSize] :
+ llvm::enumerate(llvm::reverse(oldType.getShape()))) {
+ if (dimSize == 1) {
+ continue;
+ } else {
+ idx = oldType.getRank() - (dimIdx)-1;
+ break;
+ }
+ }
+ return idx;
+}
+
struct TransferWriteNonPermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
@@ -264,6 +296,41 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.push_back(i);
exprs.push_back(rewriter.getAffineDimExpr(i));
}
+
+ // Fix for lowering transfer write when we have Scalable vectors and unit
+ // dims
+ auto sourceVectorType = op.getVectorType();
+ auto memRefType = dyn_cast<MemRefType>(op.getShapedType());
+
+ if (sourceVectorType.isScalable() && !memRefType.hasStaticShape()) {
+ int reducedRank = getReducedRank(memRefType.getShape());
+
+ auto loc = op.getLoc();
+ SmallVector<Value> indices(
+ reducedRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // Check if the result shapes has unit dim before and after the scalable
+ // and non-scalable dim
+ int firstIdx = getFirstNonUnitDim(memRefType);
+ int lastIdx = getLasttNonUnitDim(memRefType);
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices collapsedFirstIndices;
+ for (int64_t i = 0; i < firstIdx + 1; ++i)
+ collapsedFirstIndices.push_back(i);
+ reassociation.push_back(ReassociationIndices{collapsedFirstIndices});
+ ReassociationIndices collapsedIndices;
+ for (int64_t i = lastIdx; i < memRefType.getRank(); ++i)
+ collapsedIndices.push_back(i);
+ reassociation.push_back(collapsedIndices);
+ // Create mem collapse op
+ auto newOp = rewriter.create<memref::CollapseShapeOp>(loc, op.getSource(),
+ reassociation);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+ newOp, indices);
+ return success();
+ }
+
// Vector: add unit dims at the beginning of the shape.
Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
missingInnerDim.size());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..a654274f0a73e9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,22 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// CHECK-LABEL: func.func @permutation_with_masked_transfer_write_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
+// CHECK: vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
+// CHECK: return
+// CHECK: }
+ func.func @permutation_with_masked_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+ return
+ }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
@llvm/pr-subscribers-mlir Author: Crefeda Rodrigues (cfRod)

Changes: This PR fixes the issue of lowering vector transfer writes on scalable vectors with unit dims to vector broadcast ops and vector transpose ops - where the scalable dims are dropped. Full diff: https://github.com/llvm/llvm-project/pull/85270.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..cef8a497a80996 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -226,6 +226,38 @@ struct TransferWritePermutationLowering
/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
/// vector<1x8x16xf32>
/// ```
+/// Returns the number of dims that aren't unit dims.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+ return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+static int getFirstNonUnitDim(MemRefType oldType) {
+ int idx = 0;
+ for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+ if (dimSize == 1) {
+ continue;
+ } else {
+ idx = dimIdx;
+ break;
+ }
+ }
+ return idx;
+}
+
+static int getLasttNonUnitDim(MemRefType oldType) {
+ int idx = 0;
+ for (auto [dimIdx, dimSize] :
+ llvm::enumerate(llvm::reverse(oldType.getShape()))) {
+ if (dimSize == 1) {
+ continue;
+ } else {
+ idx = oldType.getRank() - (dimIdx)-1;
+ break;
+ }
+ }
+ return idx;
+}
+
struct TransferWriteNonPermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
@@ -264,6 +296,41 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.push_back(i);
exprs.push_back(rewriter.getAffineDimExpr(i));
}
+
+ // Fix for lowering transfer write when we have Scalable vectors and unit
+ // dims
+ auto sourceVectorType = op.getVectorType();
+ auto memRefType = dyn_cast<MemRefType>(op.getShapedType());
+
+ if (sourceVectorType.isScalable() && !memRefType.hasStaticShape()) {
+ int reducedRank = getReducedRank(memRefType.getShape());
+
+ auto loc = op.getLoc();
+ SmallVector<Value> indices(
+ reducedRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // Check if the result shapes has unit dim before and after the scalable
+ // and non-scalable dim
+ int firstIdx = getFirstNonUnitDim(memRefType);
+ int lastIdx = getLasttNonUnitDim(memRefType);
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices collapsedFirstIndices;
+ for (int64_t i = 0; i < firstIdx + 1; ++i)
+ collapsedFirstIndices.push_back(i);
+ reassociation.push_back(ReassociationIndices{collapsedFirstIndices});
+ ReassociationIndices collapsedIndices;
+ for (int64_t i = lastIdx; i < memRefType.getRank(); ++i)
+ collapsedIndices.push_back(i);
+ reassociation.push_back(collapsedIndices);
+ // Create mem collapse op
+ auto newOp = rewriter.create<memref::CollapseShapeOp>(loc, op.getSource(),
+ reassociation);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+ newOp, indices);
+ return success();
+ }
+
// Vector: add unit dims at the beginning of the shape.
Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
missingInnerDim.size());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..a654274f0a73e9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,22 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// CHECK-LABEL: func.func @permutation_with_masked_transfer_write_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
+// CHECK: vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
+// CHECK: return
+// CHECK: }
+ func.func @permutation_with_masked_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+ return
+ }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
thanks for the patch! We already have a pattern that drops unit dims from vector transfer ops. Could this not be extended to support this case instead of introducing more logic for handling unit dims?
// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) { | ||
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index | ||
// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16> | ||
// CHECK: vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is it no longer in-bounds?
Closing this in favour of #85632 |
…wering (#85632) Updates `extendVectorRank` so that scalability in patterns that use it (in particular, `TransferWriteNonPermutationLowering`), is correctly propagated. Closed related previous PR #85270 --------- Signed-off-by: Crefeda Rodrigues <[email protected]> Co-authored-by: Benjamin Maxwell <[email protected]>
…wering (llvm#85632) Updates `extendVectorRank` so that scalability in patterns that use it (in particular, `TransferWriteNonPermutationLowering`), is correctly propagated. Closed related previous PR llvm#85270 --------- Signed-off-by: Crefeda Rodrigues <[email protected]> Co-authored-by: Benjamin Maxwell <[email protected]>
This PR fixes the issue of lowering vector transfer writes on scalable vectors with unit dims to vector broadcast ops and vector transpose ops - where the scalable dims are dropped.