@@ -45,6 +45,18 @@ namespace {
45
45
// / Attribute name used for labeling transfer ops during progressive lowering.
46
46
static const char kPassLabel [] = " __vector_to_scf_lowering__" ;
47
47
48
+ // / Return true if this transfer op operates on a source tensor.
49
+ static bool isTensorOp (VectorTransferOpInterface xferOp) {
50
+ if (isa<RankedTensorType>(xferOp.getShapedType ())) {
51
+ if (isa<vector::TransferWriteOp>(xferOp)) {
52
+ // TransferWriteOps on tensors have a result.
53
+ assert (xferOp->getNumResults () > 0 );
54
+ }
55
+ return true ;
56
+ }
57
+ return false ;
58
+ }
59
+
48
60
// / Patterns that inherit from this struct have access to
49
61
// / VectorTransferToSCFOptions.
50
62
template <typename OpTy>
@@ -53,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
53
65
VectorTransferToSCFOptions opt)
54
66
: OpRewritePattern<OpTy>(context), options(opt) {}
55
67
68
+ LogicalResult checkLowerTensors (VectorTransferOpInterface xferOp,
69
+ PatternRewriter &rewriter) const {
70
+ if (isTensorOp (xferOp) && !options.lowerTensors ) {
71
+ return rewriter.notifyMatchFailure (
72
+ xferOp, " lowering tensor transfers is disabled" );
73
+ }
74
+ return success ();
75
+ }
76
+
56
77
VectorTransferToSCFOptions options;
57
78
};
58
79
@@ -258,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
258
279
newXferOp->setAttr (kPassLabel , b.getUnitAttr ());
259
280
}
260
281
261
- // / Return true if this transfer op operates on a source tensor.
262
- template <typename OpTy>
263
- static bool isTensorOp (OpTy xferOp) {
264
- if (isa<RankedTensorType>(xferOp.getShapedType ())) {
265
- if (xferOp.getOperationName () == TransferWriteOp::getOperationName ()) {
266
- // TransferWriteOps on tensors have a result.
267
- assert (xferOp->getNumResults () > 0 );
268
- }
269
- return true ;
270
- }
271
- return false ;
272
- }
273
-
274
282
namespace lowering_n_d {
275
283
276
284
// / Helper data structure for data and mask buffers.
@@ -1058,10 +1066,8 @@ struct ScalableTransposeTransferWriteConversion
1058
1066
1059
1067
LogicalResult matchAndRewrite (TransferWriteOp writeOp,
1060
1068
PatternRewriter &rewriter) const override {
1061
- if (isTensorOp (writeOp) && !options.lowerTensors ) {
1062
- return rewriter.notifyMatchFailure (
1063
- writeOp, " lowering tensor transfers is disabled" );
1064
- }
1069
+ if (failed (checkLowerTensors (writeOp, rewriter)))
1070
+ return failure ();
1065
1071
1066
1072
VectorType vectorType = writeOp.getVectorType ();
1067
1073
@@ -1286,9 +1292,8 @@ struct UnrollTransferReadConversion
1286
1292
if (xferOp.getVectorType ().getRank () <= options.targetRank )
1287
1293
return rewriter.notifyMatchFailure (
1288
1294
xferOp, " vector rank is less or equal to target rank" );
1289
- if (isTensorOp (xferOp) && !options.lowerTensors )
1290
- return rewriter.notifyMatchFailure (
1291
- xferOp, " transfers operating on tensors are excluded" );
1295
+ if (failed (checkLowerTensors (xferOp, rewriter)))
1296
+ return failure ();
1292
1297
// Transfer ops that modify the element type are not supported atm.
1293
1298
if (xferOp.getVectorType ().getElementType () !=
1294
1299
xferOp.getShapedType ().getElementType ())
@@ -1424,7 +1429,7 @@ struct UnrollTransferWriteConversion
1424
1429
if (inputVectorTy.getRank () <= options.targetRank )
1425
1430
return failure ();
1426
1431
1427
- if (isTensorOp ( xferOp) && !options. lowerTensors )
1432
+ if (failed ( checkLowerTensors ( xferOp, rewriter)) )
1428
1433
return failure ();
1429
1434
// Transfer ops that modify the element type are not supported atm.
1430
1435
if (inputVectorTy.getElementType () !=
0 commit comments