[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (2/N) #123031
Changes from all commits: 1abdf4f, 814379d, 48d1657
```diff
@@ -2716,56 +2716,56 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());

-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;

-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
```
Review comment: Couldn't user-provided vector sizes lead to an unmasked scenario? We have a method that checks if a mask is needed here (can't remember the name right now). Couldn't we use it for this case?

Reply: Yup, that's
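Since the helper's name didn't survive above, the following is a hypothetical sketch of the kind of check being suggested; `needsMask` is an illustrative name and not the actual MLIR helper:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical illustration only: masking is needed when the source shape
// is dynamic or differs from the user-provided vector sizes; otherwise the
// generated mask would be all-true and could be skipped.
static bool needsMask(llvm::ArrayRef<int64_t> sourceShape,
                      llvm::ArrayRef<int64_t> inputVectorSizes) {
  return mlir::ShapedType::isDynamicShape(sourceShape) ||
         !llvm::equal(sourceShape, inputVectorSizes);
}
```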
```diff
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
       LDBG("Unable to get the defining Op of " << sliceOp);
       return failure();
     }

-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
             rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }

     // Create the mask
     SmallVector<int64_t> readMaskShape(
         sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  if (maskOp) {
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }

   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

-  // 5. Finalize
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
       ArrayRef<bool>{writeInBounds});

+  if (maskOp) {
+    write = mlir::vector::maskOperation(rewriter, write, maskOp);
+  }
+
+  // 4. Finalize
   newResults.push_back(write->getResult(0));

   return success();
```
Review comment: Move the declaration to where it is initialized, i.e., l.2742?
Reply: Note that l.2742 sits within an `if` block and the generated mask is also used outside, e.g. l.2756. This is roughly the structure:
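A sketch of that structure, reconstructed from the diff above:

```cpp
// Reconstructed from the diff: `maskOp` is declared before the `if` (the
// proposed l.2742 move would trap it inside), initialized only when vector
// sizes are provided, and then consumed at both transfer ops.
Value maskOp;
if (!inputVectorSizes.empty()) {
  // ... reify the source sizes ...
  maskOp = rewriter.create<vector::CreateMaskOp>(/*...*/);      // l.2742
}

Operation *read = rewriter.create<vector::TransferReadOp>(/*...*/);
if (maskOp)
  read = mlir::vector::maskOperation(rewriter, read, maskOp);   // l.2756

Operation *write = rewriter.create<vector::TransferWriteOp>(/*...*/);
if (maskOp)
  write = mlir::vector::maskOperation(rewriter, write, maskOp);
```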
I have ideas for how to improve this, but no spare cycles 😢 (there's `createWriteOrMaskedWrite` and `createReadOrMaskedRead` that we should re-use here, but that won't work as-is). If that's OK, I'll add this to my TODO list?
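For context, a rough sketch of what reusing the read-side helper might look like; the call shape is approximated from `mlir/Dialect/Vector/Utils/VectorUtils.h` and should be treated as an assumption, and as noted above it won't work as-is for this op:

```cpp
// Approximate usage sketch (signature assumed, not verified against this
// revision): the helper emits the transfer_read and applies a mask only
// when one is actually needed for the given vector sizes.
Value read = mlir::vector::createReadOrMaskedRead(
    rewriter, sliceOp.getLoc(), source, inputVectorSizes, padValue,
    /*useInBoundsInsteadOfMasking=*/false);
```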
Reply: Sorry that I missed it. I see, thanks!