-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer #86163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -205,12 +205,19 @@ struct TransferWritePermutationLowering | |
// Generate new transfer_write operation. | ||
Value newVec = rewriter.create<vector::TransposeOp>( | ||
op.getLoc(), op.getVector(), indices); | ||
|
||
auto vectorType = cast<VectorType>(newVec.getType()); | ||
|
||
if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) { | ||
rewriter.eraseOp(newVec.getDefiningOp()); | ||
return failure(); | ||
Comment on lines
+212
to
+213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's this for? It returns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I've misunderstood this change. What's happening here? Did the test case previously lower with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It's erasing the Op to roll-back the changes :) And making sure that the new Op is not added to the list of Ops to be processed by the pattern rewriter driver.
The test case used to trigger There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously it was: TransferWriteNonPermutationLowering called first and so the two vector.broadcast are added and then a transpose for the mask and then TransferWritePermutationLowering is called to add the transpose for the input
After this patch, the second transpose is "erased"
|
||
} | ||
|
||
auto newMap = AffineMap::getMinorIdentityMap( | ||
map.getNumDims(), map.getNumResults(), rewriter.getContext()); | ||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( | ||
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), | ||
op.getMask(), newInBoundsAttr); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
@@ -273,7 +280,7 @@ struct TransferWriteNonPermutationLowering | |
missingInnerDim.size()); | ||
// Mask: add unit dims at the end of the shape. | ||
Value newMask; | ||
if (op.getMask()) | ||
if (op.getMask() && !op.getVectorType().isScalable()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why can the mask be omitted for scalable vectors? I see the input test case is also masked, but the new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can erase the unsupported transpose. Similar to how we do it TransferWritePermutationLowering? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In |
||
newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(), | ||
missingInnerDim.size()); | ||
exprs.append(map.getResults().begin(), map.getResults().end()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe
*vectorType.getScalableDims().end()
is could be an out-of-bounds access. The.end()
method returns an iterator passed the end of the array, I think what you probably want is:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, for this check there's
isLegalVectorType()
here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp#L421-L430Maybe this could be moved to some general utils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re-using that hook would make sense to me, but we can't call it
isLegalVectorType
(and, in general, we need to be careful when labeling vectors as "illegal"):vector<[4]x4xi32>
are perfectly legal at the "abstract" Vector dialect level,For scalable vectors that have only 1 scalable dim, this code is correct . @cfRod, I suggest adding a comment that we are assuming that at most 1 dim is scalable. @MacDue , I can rename and move
isLegalVectorType
in a separate PR. WDYT?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine with me 👍 (I agree with the naming, the
isLegal
was only within the context of the ArmSME pass).