Skip to content

Commit 7b50608

Browse files
committed
Add checkLowerTensors helper
1 parent d70d796 commit 7b50608

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,18 @@ namespace {
4545
/// Attribute name used for labeling transfer ops during progressive lowering.
4646
static const char kPassLabel[] = "__vector_to_scf_lowering__";
4747

48+
/// Return true if this transfer op operates on a source tensor.
49+
static bool isTensorOp(VectorTransferOpInterface xferOp) {
50+
if (isa<RankedTensorType>(xferOp.getShapedType())) {
51+
if (isa<vector::TransferWriteOp>(xferOp)) {
52+
// TransferWriteOps on tensors have a result.
53+
assert(xferOp->getNumResults() > 0);
54+
}
55+
return true;
56+
}
57+
return false;
58+
}
59+
4860
/// Patterns that inherit from this struct have access to
4961
/// VectorTransferToSCFOptions.
5062
template <typename OpTy>
@@ -53,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
5365
VectorTransferToSCFOptions opt)
5466
: OpRewritePattern<OpTy>(context), options(opt) {}
5567

68+
LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
69+
PatternRewriter &rewriter) const {
70+
if (isTensorOp(xferOp) && !options.lowerTensors) {
71+
return rewriter.notifyMatchFailure(
72+
xferOp, "lowering tensor transfers is disabled");
73+
}
74+
return success();
75+
}
76+
5677
VectorTransferToSCFOptions options;
5778
};
5879

@@ -258,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
258279
newXferOp->setAttr(kPassLabel, b.getUnitAttr());
259280
}
260281

261-
/// Return true if this transfer op operates on a source tensor.
262-
template <typename OpTy>
263-
static bool isTensorOp(OpTy xferOp) {
264-
if (isa<RankedTensorType>(xferOp.getShapedType())) {
265-
if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
266-
// TransferWriteOps on tensors have a result.
267-
assert(xferOp->getNumResults() > 0);
268-
}
269-
return true;
270-
}
271-
return false;
272-
}
273-
274282
namespace lowering_n_d {
275283

276284
/// Helper data structure for data and mask buffers.
@@ -1058,10 +1066,8 @@ struct ScalableTransposeTransferWriteConversion
10581066

10591067
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
10601068
PatternRewriter &rewriter) const override {
1061-
if (isTensorOp(writeOp) && !options.lowerTensors) {
1062-
return rewriter.notifyMatchFailure(
1063-
writeOp, "lowering tensor transfers is disabled");
1064-
}
1069+
if (failed(checkLowerTensors(writeOp, rewriter)))
1070+
return failure();
10651071

10661072
VectorType vectorType = writeOp.getVectorType();
10671073

@@ -1286,9 +1292,8 @@ struct UnrollTransferReadConversion
12861292
if (xferOp.getVectorType().getRank() <= options.targetRank)
12871293
return rewriter.notifyMatchFailure(
12881294
xferOp, "vector rank is less or equal to target rank");
1289-
if (isTensorOp(xferOp) && !options.lowerTensors)
1290-
return rewriter.notifyMatchFailure(
1291-
xferOp, "transfers operating on tensors are excluded");
1295+
if (failed(checkLowerTensors(xferOp, rewriter)))
1296+
return failure();
12921297
// Transfer ops that modify the element type are not supported atm.
12931298
if (xferOp.getVectorType().getElementType() !=
12941299
xferOp.getShapedType().getElementType())
@@ -1424,7 +1429,7 @@ struct UnrollTransferWriteConversion
14241429
if (inputVectorTy.getRank() <= options.targetRank)
14251430
return failure();
14261431

1427-
if (isTensorOp(xferOp) && !options.lowerTensors)
1432+
if (failed(checkLowerTensors(xferOp, rewriter)))
14281433
return failure();
14291434
// Transfer ops that modify the element type are not supported atm.
14301435
if (inputVectorTy.getElementType() !=

0 commit comments

Comments
 (0)