Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer #86163

Closed

cfRod
Contributor

@cfRod cfRod commented Mar 21, 2024

Addresses comment in #85632 (comment)
Co-authored by @banach-space

@llvmbot
Member

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Crefeda Rodrigues (cfRod)

Changes

Addresses comment in #85632 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/86163.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+15-4)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+17)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..570a5222862b72 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   SmallVector<int64_t> newShape(addedRank, 1);
   newShape.append(originalVecType.getShape().begin(),
                   originalVecType.getShape().end());
-  VectorType newVecType =
-      VectorType::get(newShape, originalVecType.getElementType());
+
+  SmallVector<bool> newScalableDims(addedRank, false);
+  newScalableDims.append(originalVecType.getScalableDims().begin(),
+                         originalVecType.getScalableDims().end());
+  VectorType newVecType = VectorType::get(
+      newShape, originalVecType.getElementType(), newScalableDims);
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }
 
@@ -201,12 +205,19 @@ struct TransferWritePermutationLowering
     // Generate new transfer_write operation.
     Value newVec = rewriter.create<vector::TransposeOp>(
         op.getLoc(), op.getVector(), indices);
+
+    auto vectorType = cast<VectorType>(newVec.getType());
+
+    if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+      rewriter.eraseOp(newVec.getDefiningOp());
+      return failure();
+    }
+
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
         op.getMask(), newInBoundsAttr);
-
     return success();
   }
 };
@@ -269,7 +280,7 @@ struct TransferWriteNonPermutationLowering
                                     missingInnerDim.size());
     // Mask: add unit dims at the end of the shape.
     Value newMask;
-    if (op.getMask())
+    if (op.getMask() && !op.getVectorType().isScalable())
       newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                                missingInnerDim.size());
     exprs.append(map.getResults().begin(), map.getResults().end());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..83a7f21daf683f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
+// CHECK-LABEL:   func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME:                                                             %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:                                                             %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:                                                             %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK:           vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK:           return
+// CHECK:         }
+  func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
+     %c0 = arith.constant 0 : index
+      vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+
+    return
+  }
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@cfRod
Contributor Author

cfRod commented Mar 21, 2024

@cfRod cfRod changed the title from "Fix unsupported transpose ops scalable" to "Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer" on Mar 21, 2024

auto vectorType = cast<VectorType>(newVec.getType());

if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
Member

I believe *vectorType.getScalableDims().end() could be an out-of-bounds access. The .end() method returns an iterator past the end of the array. I think what you probably want is:

Suggested change
-    if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+    if (vectorType.isScalable() && !vectorType.getScalableDims().back()) {
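
To illustrate the difference between the two spellings, here is a minimal standalone sketch. It uses `std::vector<bool>` as a stand-in for the list of per-dimension flags returned by `getScalableDims()`; the type alias and both function names are hypothetical and are not part of MLIR.

```cpp
#include <cassert>
#include <vector>

// Stand-in (assumption) for the per-dimension scalable flags of a vector
// type: one entry per dimension, true where the dimension is scalable.
using ScalableDims = std::vector<bool>;

// Returns true if the trailing (innermost) dimension is scalable.
// back() reads the last element; dereferencing end() would read one
// element past it, which is undefined behaviour.
bool isTrailingDimScalable(const ScalableDims &dims) {
  assert(!dims.empty() && "expected at least one dimension");
  return dims.back();
}

// Equivalent iterator spelling: step back from end() before reading.
bool isTrailingDimScalableViaIterator(const ScalableDims &dims) {
  assert(!dims.empty() && "expected at least one dimension");
  return *(dims.end() - 1);
}
```

Both helpers read the same element; the only safe way to use `end()` here is to move the iterator back by one before dereferencing it.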

Member

Btw, for this check there's isLegalVectorType() here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp#L421-L430

Maybe this could be moved to some general utils?

Contributor

Btw, for this check there's isLegalVectorType()

Re-using that hook would make sense to me, but we can't call it isLegalVectorType (and, in general, we need to be careful when labeling vectors as "illegal"):

  • vectors like vector<[4]x4xi32> are perfectly legal at the "abstract" Vector dialect level,
  • only once we start lowering to LLVM (and/or SVE/SME) do these types need to be eliminated, and it then makes sense to consider them "illegal".

For scalable vectors that have only 1 scalable dim, this code is correct. @cfRod, I suggest adding a comment that we are assuming that at most 1 dim is scalable. @MacDue, I can rename and move isLegalVectorType in a separate PR. WDYT?
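
As a rough sketch of what a more general check might look like (one that does not rely on the "at most 1 scalable dim" assumption), the helper below accepts a list of scalable flags only when no scalable dimension sits outside a fixed-size one. It is written against `std::vector<bool>` so that it is self-contained; the name `scalableDimsAreTrailing` is hypothetical, and this is not a copy of the linked `isLegalVectorType()`.

```cpp
#include <vector>

// Returns true when every scalable dimension is in trailing position, i.e.
// no scalable dimension is followed by a fixed-size one. Under the
// "at most one scalable dim" assumption this reduces to "the last dim is
// scalable, or nothing is scalable".
bool scalableDimsAreTrailing(const std::vector<bool> &scalableDims) {
  bool seenFixedDim = false;
  // Walk from the innermost dimension outwards.
  for (auto it = scalableDims.rbegin(); it != scalableDims.rend(); ++it) {
    if (!*it)
      seenFixedDim = true;
    else if (seenFixedDim)
      return false; // Scalable dim outside a fixed-size dim, e.g. [4]x4.
  }
  return true;
}
```

With this shape, the flags for `vector<4x[8]xi16>` pass and those for `vector<[4]x4xi32>` do not. The neutral name reflects the point above: such a type is only a problem for particular lowerings, not at the Vector dialect level.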

Member

Fine with me 👍 (I agree with the naming, the isLegal was only within the context of the ArmSME pass).

Comment on lines +212 to +213
rewriter.eraseOp(newVec.getDefiningOp());
return failure();
Member

What's this for? It returns failure() so this is a case where the pattern should not apply (and any changes rolled-back, I think), but it's erasing an operation.

Member

@MacDue MacDue Mar 22, 2024

I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and TransferWriteNonPermutationLowering now manages to lower the same thing?

Contributor

and any changes rolled-back, I think), but it's erasing an operation.

It's erasing the Op to roll back the changes :) It also makes sure that the new Op is not added to the list of Ops to be processed by the pattern rewriter driver.

I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and TransferWriteNonPermutationLowering now manages to lower the same thing?

The test case used to trigger TransferWritePermutationLowering, but that's now disabled and it looks like some other pattern is triggered. @cfRod, do you know which one?

Contributor Author

@cfRod cfRod Mar 22, 2024

Previously, TransferWriteNonPermutationLowering was called first, adding the two vector.broadcast ops and a transpose for the mask. TransferWritePermutationLowering was then called to add the transpose for the input:

* Pattern (anonymous namespace)::TransferWriteNonPermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWriteNonPermutationLowering"
  ** Insert  : 'vector.broadcast'(0xac685f872c10)
  ** Insert  : 'vector.broadcast'(0xac685f879e40)
  ** Insert  : 'vector.transpose'(0xac685f879ed0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::vector::detail::TransferWriteOpGenericAdaptorBase::Properties)
  ** Insert  : 'vector.transfer_write'(0xac685f7e1d00)
  ** Replace : 'vector.transfer_write'(0xac685f7b1eb0)
  ** Erase   : 'vector.transfer_write'(0xac685f7b1eb0)
"(anonymous namespace)::TransferWriteNonPermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
return
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'vector.transfer_write'(0xac685f7e1d00) {
"vector.transfer_write"(%1, %arg1, %0, %0, %0, %0, %0, %0, %0, %3) <{in_bounds = [true, true, true, true, true, true], operandSegmentSizes = array<i32: 1, 1, 7, 1>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>}> : (vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>, index, index, index, index, index, index, index, vector<4x[8]x1x1x1x1xi1>) -> ()

ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::VectorType>::Impl<Empty>)

* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
  ** Insert  : 'vector.transpose'(0xac685f87b6b0)
  ** Insert  : 'vector.transfer_write'(0xac685f7b1eb0)
  ** Replace : 'vector.transfer_write'(0xac685f7e1d00)
  ** Erase   : 'vector.transfer_write'(0xac685f7e1d00)
"(anonymous namespace)::TransferWritePermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
%3 = vector.transpose %0, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
vector.transfer_write %3, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
return
}

After this patch, the second transpose is "erased":

* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
    ** Insert  : 'vector.transpose'(0xaea769721de0)
    ** Erase   : 'vector.transpose'(0xaea769721de0)
"(anonymous namespace)::TransferWritePermutationLowering" result 0
  } -> failure : pattern failed to match
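
The create-then-erase sequence in the trace could in principle be avoided by working out the transposed scalable flags before any op is built. The sketch below shows that computation on plain `std::vector`s. It assumes the usual `vector.transpose` convention that result dimension `i` comes from source dimension `perm[i]`, and that a type counts as scalable when any dimension is scalable. Both function names are hypothetical and nothing here calls into MLIR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Applies a transpose-style permutation to per-dimension scalable flags:
// result dimension i takes its flag from source dimension perm[i].
std::vector<bool> permuteScalableDims(const std::vector<bool> &scalableDims,
                                      const std::vector<std::size_t> &perm) {
  assert(perm.size() == scalableDims.size() && "rank mismatch");
  std::vector<bool> result;
  result.reserve(perm.size());
  for (std::size_t srcDim : perm) {
    assert(srcDim < scalableDims.size() && "permutation index out of range");
    result.push_back(scalableDims[srcDim]);
  }
  return result;
}

// True when the transposed type would still be scalable but its trailing
// dimension would be fixed-size, which is the case the patch bails out on.
bool transposeLosesTrailingScalableDim(const std::vector<bool> &scalableDims,
                                       const std::vector<std::size_t> &perm) {
  std::vector<bool> permuted = permuteScalableDims(scalableDims, perm);
  bool anyScalable =
      std::find(permuted.begin(), permuted.end(), true) != permuted.end();
  return anyScalable && !permuted.back();
}
```

For the test case in this thread, the source flags for `vector<1x1x1x1x4x[8]xi16>` with permutation `[4, 5, 0, 1, 2, 3]` give the flags of `vector<4x[8]x1x1x1x1xi16>`, so the second helper returns true and the pattern could return `failure()` without having inserted anything.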

@@ -269,7 +280,7 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.size());
// Mask: add unit dims at the end of the shape.
Value newMask;
-    if (op.getMask())
+    if (op.getMask() && !op.getVectorType().isScalable())
Member

Why can the mask be omitted for scalable vectors? I see the input test case is also masked, but the new vector.transfer_write is unmasked.

Contributor Author

We can erase the unsupported transpose. Similar to how we do it in TransferWritePermutationLowering?

Member

In TransferWritePermutationLowering the rewrite fails (i.e. it does nothing), erasing the transpose is just doing a cleanup. Here the rewrite succeeds but the mask is omitted, which changes the semantics of the operation, which could lead to incorrect results.
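
To make the semantic difference concrete, here is a toy one-dimensional model of a masked and an unmasked write. It is only an illustration of why dropping the mask is not a no-op: `maskedWrite` and `unmaskedWrite` are hypothetical helpers, and the model ignores everything specific to `vector.transfer_write` (permutation maps, in-bounds attributes, multi-dimensional shapes).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Masked write: dest[i] is overwritten only where mask[i] is set; all other
// destination elements keep their previous value.
void maskedWrite(const std::vector<int> &src, const std::vector<bool> &mask,
                 std::vector<int> &dest) {
  assert(src.size() == mask.size() && src.size() <= dest.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    if (mask[i])
      dest[i] = src[i];
}

// Unmasked write: every destination element covered by src is overwritten.
void unmaskedWrite(const std::vector<int> &src, std::vector<int> &dest) {
  assert(src.size() <= dest.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    dest[i] = src[i];
}
```

With a source of `{1, 2, 3, 4}`, a mask of `{1, 1, 0, 0}` and a zero-filled destination, the masked write leaves `{1, 2, 0, 0}` while the unmasked write leaves `{1, 2, 3, 4}`. A rewrite that succeeds but silently drops the mask therefore stores to elements the original operation would have left untouched.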

…for scalable vectors

Signed-off-by: Crefeda Rodrigues <[email protected]>
@cfRod cfRod force-pushed the fix_unsupported_transpose_ops_scalable branch from 0c6c1d1 to 1f9f3cc Compare March 22, 2024 14:06
erthalion added a commit to erthalion/llvm-project that referenced this pull request Mar 29, 2024
The commit 2da4960 enabled `noundef` attributes propagation. It looks
like ret void is considered to be `noundef` thus `ValidUB.hasAttributes`
now returns true for this type of instructions and everything proceed
further to work with operands. The issue is that such instruction
doesn't have operands, which means when accessing `RI->getOperand(0)`
inliner pass crashes with an assert:

    llvm/include/llvm/IR/Instructions.h:3420: llvm::Value* llvm::ReturnInst::getOperand(unsigned int) const:
    Assertion `i_nocapture < OperandTraits<ReturnInst>::operands(this) && "getOperand() out of range!"' failed.

Fix that by verifying if the ReturnInst in fact has some operands to
process.

Fixes llvm#86163
@cfRod cfRod closed this May 28, 2024