Skip to content

[mlir][vector] Fix rewrite pattern API violation in VectorToSCF #77909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,10 +1060,10 @@ struct UnrollTransferReadConversion
setHasBoundedRewriteRecursion();
}

/// Return the vector into which the newly created TransferReadOp results
/// are inserted.
Value getResultVector(TransferReadOp xferOp,
PatternRewriter &rewriter) const {
/// Get or build the vector into which the newly created TransferReadOp
/// results are inserted.
Value buildResultVector(PatternRewriter &rewriter,
TransferReadOp xferOp) const {
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
Expand Down Expand Up @@ -1098,24 +1098,27 @@ struct UnrollTransferReadConversion
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
return rewriter.notifyMatchFailure(
xferOp, "vector rank is less or equal to target rank");
if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
return rewriter.notifyMatchFailure(
xferOp, "transfers operating on tensors are excluded");
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
xferOp.getShapedType().getElementType())
return failure();

auto insertOp = getInsertOp(xferOp);
auto vec = getResultVector(xferOp, rewriter);
auto vecType = dyn_cast<VectorType>(vec.getType());
return rewriter.notifyMatchFailure(
xferOp, "not yet supported: element type mismatch");
auto xferVecType = xferOp.getVectorType();

if (xferVecType.getScalableDims()[0]) {
// Cannot unroll a scalable dimension at compile time.
return failure();
return rewriter.notifyMatchFailure(
xferOp, "scalable dimensions cannot be unrolled");
}

auto insertOp = getInsertOp(xferOp);
auto vec = buildResultVector(rewriter, xferOp);
auto vecType = dyn_cast<VectorType>(vec.getType());

VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

int64_t dimSize = xferVecType.getShape()[0];
Expand Down