
[mlir][vector] Rewrite vector transfer write with unit dims for scalable vectors #85270


Conversation

cfRod
Contributor

@cfRod cfRod commented Mar 14, 2024

This PR fixes the lowering of vector transfer writes on scalable vectors with unit dims: the existing lowering to vector broadcast and vector transpose ops drops the scalable dims.
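
For context, here is a minimal C++ sketch (not from this patch; the function and variable names are illustrative) of how scalability gets lost when a lowering rebuilds a vector type from its shape alone, which is the failure mode described above:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Sketch only: prepending unit dims by rebuilding the type from the shape
// alone turns every scalable dim into a fixed one, e.g. vector<4x[8]xi16>
// plus one leading unit dim becomes vector<1x4x8xi16> instead of
// vector<1x4x[8]xi16>.
static VectorType lossyExtend(VectorType srcTy, int64_t addedRank) {
  llvm::SmallVector<int64_t> newShape(addedRank, 1);
  llvm::append_range(newShape, srcTy.getShape());
  // srcTy.getScalableDims() is never consulted, so every dim of the result
  // is fixed-size even when the source type had scalable dims.
  return VectorType::get(newShape, srcTy.getElementType());
}
```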


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository; in that case, you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "pinging" the PR, i.e. adding a comment that says "Ping". The common courtesy
ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-mlir-vector

Author: Crefeda Rodrigues (cfRod)

Changes

This PR fixes the lowering of vector transfer writes on scalable vectors with unit dims: the existing lowering to vector broadcast and vector transpose ops drops the scalable dims.


Full diff: https://github.com/llvm/llvm-project/pull/85270.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+67)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+16)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..cef8a497a80996 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -226,6 +226,38 @@ struct TransferWritePermutationLowering
 ///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
 ///     vector<1x8x16xf32>
 /// ```
+/// Returns the number of dims that aren't unit dims.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+static int getFirstNonUnitDim(MemRefType oldType) {
+  int idx = 0;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+    if (dimSize == 1) {
+      continue;
+    } else {
+      idx = dimIdx;
+      break;
+    }
+  }
+  return idx;
+}
+
+static int getLasttNonUnitDim(MemRefType oldType) {
+  int idx = 0;
+  for (auto [dimIdx, dimSize] :
+       llvm::enumerate(llvm::reverse(oldType.getShape()))) {
+    if (dimSize == 1) {
+      continue;
+    } else {
+      idx = oldType.getRank() - (dimIdx)-1;
+      break;
+    }
+  }
+  return idx;
+}
+
 struct TransferWriteNonPermutationLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -264,6 +296,41 @@ struct TransferWriteNonPermutationLowering
       missingInnerDim.push_back(i);
       exprs.push_back(rewriter.getAffineDimExpr(i));
     }
+
+    // Fix for lowering transfer write when we have Scalable vectors and unit
+    // dims
+    auto sourceVectorType = op.getVectorType();
+    auto memRefType = dyn_cast<MemRefType>(op.getShapedType());
+
+    if (sourceVectorType.isScalable() && !memRefType.hasStaticShape()) {
+      int reducedRank = getReducedRank(memRefType.getShape());
+
+      auto loc = op.getLoc();
+      SmallVector<Value> indices(
+          reducedRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+      // Check if the result shapes has unit dim before and after the scalable
+      // and non-scalable dim
+      int firstIdx = getFirstNonUnitDim(memRefType);
+      int lastIdx = getLasttNonUnitDim(memRefType);
+
+      SmallVector<ReassociationIndices> reassociation;
+      ReassociationIndices collapsedFirstIndices;
+      for (int64_t i = 0; i < firstIdx + 1; ++i)
+        collapsedFirstIndices.push_back(i);
+      reassociation.push_back(ReassociationIndices{collapsedFirstIndices});
+      ReassociationIndices collapsedIndices;
+      for (int64_t i = lastIdx; i < memRefType.getRank(); ++i)
+        collapsedIndices.push_back(i);
+      reassociation.push_back(collapsedIndices);
+      // Create mem collapse op
+      auto newOp = rewriter.create<memref::CollapseShapeOp>(loc, op.getSource(),
+                                                            reassociation);
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+                                                           newOp, indices);
+      return success();
+    }
+
     // Vector: add unit dims at the beginning of the shape.
     Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                     missingInnerDim.size());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..a654274f0a73e9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,22 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
+// CHECK-LABEL:     func.func @permutation_with_masked_transfer_write_scalable(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:        %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:        %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
+// CHECK:             vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
+// CHECK:             return
+// CHECK:           }
+  func.func @permutation_with_masked_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
+     %c0 = arith.constant 0 : index
+      vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+    return
+  }
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@llvmbot
Member

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-mlir

@cfRod
Contributor Author

cfRod commented Mar 14, 2024

@c-rhodes
Collaborator

thanks for the patch! We already have a pattern TransferWriteDropUnitDimsPattern for dropping unit dims on transfer_write ops in mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp.

Could this not be extended to support this case instead of introducing more logic for handling unit dims?
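
As a rough illustration of this suggestion, here is a hypothetical helper (not part of this PR) that exercises the existing unit-dim-dropping patterns through the Vector dialect's populate entry point; whether the scalable case above is handled would still depend on extending TransferWriteDropUnitDimsPattern itself:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Greedily applies the transfer_read/transfer_write unit-dim-dropping
// patterns registered by the populate call below to everything under `root`.
static LogicalResult dropTransferUnitDims(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorTransferDropUnitDimsPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```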

// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
// CHECK: vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
Collaborator


why is it no longer in-bounds?

@cfRod
Contributor Author

cfRod commented Mar 18, 2024

Closing this in favour of #85632

@cfRod cfRod closed this Mar 18, 2024
banach-space pushed a commit that referenced this pull request Mar 22, 2024
…wering (#85632)

Updates `extendVectorRank` so that scalability in patterns
that use it (in particular, `TransferWriteNonPermutationLowering`),
is correctly propagated.


Closed related previous PR
#85270

---------

Signed-off-by: Crefeda Rodrigues <[email protected]>
Co-authored-by: Benjamin Maxwell <[email protected]>
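
For reference, here is a hedged sketch of what propagating scalability through extendVectorRank could look like, based on the commit description above; the exact code in #85632 may differ:

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Broadcasts `vec` to a rank extended by `addedRank` leading unit dims while
// keeping the original scalable-dim flags; the added unit dims are fixed.
static Value extendVectorRankKeepingScalability(OpBuilder &builder,
                                                Location loc, Value vec,
                                                int64_t addedRank) {
  auto srcTy = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  llvm::append_range(newShape, srcTy.getShape());
  SmallVector<bool> newScalableDims(addedRank, false);
  llvm::append_range(newScalableDims, srcTy.getScalableDims());
  auto newTy =
      VectorType::get(newShape, srcTy.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newTy, vec);
}
```

With this, the leading unit dims introduced by TransferWriteNonPermutationLowering stay fixed-size while the original scalable dims keep their [n] markers.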
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024