Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer #86163
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Crefeda Rodrigues (cfRod)
Changes
Addresses comment in #85632 (comment)
Full diff: https://github.com/llvm/llvm-project/pull/86163.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..570a5222862b72 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
- VectorType newVecType =
- VectorType::get(newShape, originalVecType.getElementType());
+
+ SmallVector<bool> newScalableDims(addedRank, false);
+ newScalableDims.append(originalVecType.getScalableDims().begin(),
+ originalVecType.getScalableDims().end());
+ VectorType newVecType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
@@ -201,12 +205,19 @@ struct TransferWritePermutationLowering
// Generate new transfer_write operation.
Value newVec = rewriter.create<vector::TransposeOp>(
op.getLoc(), op.getVector(), indices);
+
+ auto vectorType = cast<VectorType>(newVec.getType());
+
+ if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+ rewriter.eraseOp(newVec.getDefiningOp());
+ return failure();
+ }
+
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
op.getMask(), newInBoundsAttr);
-
return success();
}
};
@@ -269,7 +280,7 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.size());
// Mask: add unit dims at the end of the shape.
Value newMask;
- if (op.getMask())
+ if (op.getMask() && !op.getVectorType().isScalable())
newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
missingInnerDim.size());
exprs.append(map.getResults().begin(), map.getResults().end());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..83a7f21daf683f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// CHECK-LABEL: func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK: vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK: return
+// CHECK: }
+ func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+
+ return
+ }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
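The extendVectorRank hunk in the diff above prepends unit dims to the shape and, with this patch, prepends matching non-scalable flags so the original scalable dims keep their positions. A minimal standalone sketch of that bookkeeping follows. It uses plain std::vector instead of MLIR's VectorType, and `ShapeWithFlags` and `extendRank` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: a shape plus one scalable flag per dim.
struct ShapeWithFlags {
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

// Prepend `addedRank` unit dims. The new leading dims are fixed-size, so their
// flags are false; the original flags are carried over unchanged.
ShapeWithFlags extendRank(const ShapeWithFlags &orig, std::size_t addedRank) {
  assert(orig.shape.size() == orig.scalable.size() && "one flag per dim");
  ShapeWithFlags result;
  result.shape.assign(addedRank, 1);
  result.shape.insert(result.shape.end(), orig.shape.begin(), orig.shape.end());
  result.scalable.assign(addedRank, false);
  result.scalable.insert(result.scalable.end(), orig.scalable.begin(),
                         orig.scalable.end());
  return result;
}
```

Under these assumptions, extending a 4x[8] shape by four dims gives 1x1x1x1x4x[8], which matches the broadcast result type in the new test.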
auto vectorType = cast<VectorType>(newVec.getType());

if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
I believe *vectorType.getScalableDims().end() could be an out-of-bounds access. The .end() method returns an iterator past the end of the array. I think what you probably want is:
- if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+ if (vectorType.isScalable() && !vectorType.getScalableDims().back()) {
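The difference between the two forms can be shown outside MLIR. In the sketch below, a std::vector<bool> stands in for the array of scalable flags, and `isTrailingDimScalable` is a hypothetical helper name, not an existing MLIR function.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: true if the innermost (last) dimension is scalable.
bool isTrailingDimScalable(const std::vector<bool> &scalableDims) {
  // There is no trailing dim to inspect on an empty list.
  assert(!scalableDims.empty() && "expected at least one dimension");
  // back() reads the last element. end() points one past the last element,
  // so dereferencing it reads out of bounds.
  return scalableDims.back();
}
```

For the flags of vector<4x[8]xi16> this returns true. For the flags of vector<4x[8]x1x1x1x1xi16> it returns false, which is the case the pattern wants to reject.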
Btw, for this check there's isLegalVectorType()
here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp#L421-L430
Maybe this could be moved to some general utils?
> Btw, for this check there's isLegalVectorType()

Re-using that hook would make sense to me, but we can't call it isLegalVectorType (and, in general, we need to be careful when labeling vectors as "illegal"):
- vectors like vector<[4]x4xi32> are perfectly legal at the "abstract" Vector dialect level,
- only once we start lowering to LLVM (and/or SVE/SME) do these types need to be eliminated, and only then does it make sense to consider them "illegal".
For scalable vectors that have only 1 scalable dim, this code is correct. @cfRod, I suggest adding a comment that we are assuming that at most 1 dim is scalable. @MacDue, I can rename and move isLegalVectorType in a separate PR. WDYT?
Fine with me 👍 (I agree with the naming, the isLegal
was only within the context of the ArmSME pass).
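For reference, a check of the kind discussed in this thread could look like the sketch below. It is not the ArmSME helper itself: the name `hasOnlyTrailingScalableDims` is hypothetical, the flags are passed as a plain std::vector<bool>, and the rule it encodes (no fixed dim may follow a scalable dim) is an assumption about what a lowering to LLVM/SVE/SME would require.

```cpp
#include <cassert>
#include <vector>

// Hypothetical check: accept a type only if all scalable dims are trailing,
// i.e. no fixed dim comes after a scalable one.
bool hasOnlyTrailingScalableDims(const std::vector<bool> &scalableDims) {
  bool seenFixedDim = false;
  // Walk from the innermost dim outwards.
  for (auto it = scalableDims.rbegin(); it != scalableDims.rend(); ++it) {
    if (!*it)
      seenFixedDim = true;
    else if (seenFixedDim)
      return false; // a scalable dim with a fixed dim somewhere after it
  }
  return true;
}
```

With this rule, the flags of vector<4x[8]xi16> pass and the flags of vector<[4]x4xi32> do not. When at most one dim is scalable, the result agrees with looking at the last flag only, which is the assumption the review asks to document.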
rewriter.eraseOp(newVec.getDefiningOp());
return failure();
What's this for? It returns failure()
so this is a case where the pattern should not apply (and any changes rolled-back, I think), but it's erasing an operation.
I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and TransferWriteNonPermutationLowering now manages to lower the same thing?
> and any changes rolled-back, I think), but it's erasing an operation.

It's erasing the Op to roll back the changes :) It also makes sure that the new Op is not added to the list of Ops to be processed by the pattern rewriter driver.

> I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and now TransferWriteNonPermutationLowering now manages to lower the same thing?

The test case used to trigger TransferWritePermutationLowering, but that's now disabled and it looks like some other pattern is triggered. @cfRod, do you know which one?
Previously, TransferWriteNonPermutationLowering was called first, which added the two vector.broadcast ops and then a transpose for the mask. TransferWritePermutationLowering was then called to add the transpose for the input:
* Pattern (anonymous namespace)::TransferWriteNonPermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWriteNonPermutationLowering"
** Insert : 'vector.broadcast'(0xac685f872c10)
** Insert : 'vector.broadcast'(0xac685f879e40)
** Insert : 'vector.transpose'(0xac685f879ed0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::vector::detail::TransferWriteOpGenericAdaptorBase::Properties)
** Insert : 'vector.transfer_write'(0xac685f7e1d00)
** Replace : 'vector.transfer_write'(0xac685f7b1eb0)
** Erase : 'vector.transfer_write'(0xac685f7b1eb0)
"(anonymous namespace)::TransferWriteNonPermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
return
}
} -> success : pattern matched
//===-------------------------------------------===//
//===-------------------------------------------===//
Processing operation : 'vector.transfer_write'(0xac685f7e1d00) {
"vector.transfer_write"(%1, %arg1, %0, %0, %0, %0, %0, %0, %0, %3) <{in_bounds = [true, true, true, true, true, true], operandSegmentSizes = array<i32: 1, 1, 7, 1>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>}> : (vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>, index, index, index, index, index, index, index, vector<4x[8]x1x1x1x1xi1>) -> ()
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::VectorType>::Impl<Empty>)
* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
** Insert : 'vector.transpose'(0xac685f87b6b0)
** Insert : 'vector.transfer_write'(0xac685f7b1eb0)
** Replace : 'vector.transfer_write'(0xac685f7e1d00)
** Erase : 'vector.transfer_write'(0xac685f7e1d00)
"(anonymous namespace)::TransferWritePermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
%3 = vector.transpose %0, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
vector.transfer_write %3, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
return
}
After this patch, the second transpose is "erased":
* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
** Insert : 'vector.transpose'(0xaea769721de0)
** Erase : 'vector.transpose'(0xaea769721de0)
"(anonymous namespace)::TransferWritePermutationLowering" result 0
} -> failure : pattern failed to match
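The log above shows the transpose being inserted and then erased before the pattern reports failure. Whether a transpose would produce a scalable vector with a fixed trailing dim can also be decided from the input flags and the permutation alone, before any op is created. The sketch below is only an illustration of that idea and is not what the PR does; `permuteScalableDims` and `resultHasFixedTrailingDim` are hypothetical names, and the flags are modelled as a std::vector<bool>.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Result dim i of a transpose takes source dim perm[i], so the scalable flags
// are permuted the same way as the shape.
std::vector<bool> permuteScalableDims(const std::vector<bool> &scalableDims,
                                      const std::vector<std::size_t> &perm) {
  assert(perm.size() == scalableDims.size() && "permutation must cover every dim");
  std::vector<bool> result;
  result.reserve(perm.size());
  for (std::size_t srcDim : perm) {
    assert(srcDim < scalableDims.size() && "permutation index out of range");
    result.push_back(scalableDims[srcDim]);
  }
  return result;
}

// Hypothetical pre-check: would the transposed vector be scalable but end in a
// fixed-size dim?
bool resultHasFixedTrailingDim(const std::vector<bool> &scalableDims,
                               const std::vector<std::size_t> &perm) {
  std::vector<bool> result = permuteScalableDims(scalableDims, perm);
  bool anyScalable = false;
  for (bool flag : result)
    anyScalable = anyScalable || flag;
  return anyScalable && !result.empty() && !result.back();
}
```

For the transpose in the log, the input flags of vector<1x1x1x1x4x[8]xi16> with permutation [4, 5, 0, 1, 2, 3] give the flags of vector<4x[8]x1x1x1x1xi16>, so the pre-check reports the unsupported case.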
@@ -269,7 +280,7 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.size());
// Mask: add unit dims at the end of the shape.
Value newMask;
- if (op.getMask())
+ if (op.getMask() && !op.getVectorType().isScalable())
Why can the mask be omitted for scalable vectors? I see the input test case is also masked, but the new vector.transfer_write
is unmasked.
We can erase the unsupported transpose, similar to how we do it in TransferWritePermutationLowering?
In TransferWritePermutationLowering the rewrite fails (i.e. it does nothing); erasing the transpose is just cleanup. Here the rewrite succeeds but the mask is omitted, which changes the semantics of the operation and could lead to incorrect results.
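The semantic difference can be illustrated with a toy model of a masked write. This is a sketch on plain std::vector data, not MLIR; `maskedWrite` and `unmaskedWrite` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Write src into dest, but only in lanes where the mask is set.
void maskedWrite(std::vector<int> &dest, const std::vector<int> &src,
                 const std::vector<bool> &mask) {
  assert(dest.size() == src.size() && src.size() == mask.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    if (mask[i])
      dest[i] = src[i];
}

// The same write with the mask dropped: every lane is overwritten.
void unmaskedWrite(std::vector<int> &dest, const std::vector<int> &src) {
  assert(dest.size() == src.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    dest[i] = src[i];
}
```

With a mask that disables the last two lanes, the masked write leaves those destination lanes untouched while the unmasked write clobbers them. A rewrite that silently drops the mask therefore does not preserve the original behavior.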
…for scalable vectors Signed-off-by: Crefeda Rodrigues <[email protected]>
0c6c1d1
to
1f9f3cc
Compare
The commit 2da4960 enabled `noundef` attribute propagation. It looks like `ret void` is considered to be `noundef`, so `ValidUB.hasAttributes` now returns true for this type of instruction and execution proceeds to work with its operands. The issue is that such an instruction doesn't have operands, which means that when accessing `RI->getOperand(0)` the inliner pass crashes with an assert: llvm/include/llvm/IR/Instructions.h:3420: llvm::Value* llvm::ReturnInst::getOperand(unsigned int) const: Assertion `i_nocapture < OperandTraits<ReturnInst>::operands(this) && "getOperand() out of range!"' failed. Fix that by verifying that the ReturnInst in fact has some operands to process. Fixes llvm#86163
Addresses comment in #85632 (comment)
Co-authored by @banach-space