#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
@@ -44,6 +45,18 @@ namespace {
/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

+/// Return true if this transfer op operates on a source tensor.
+static bool isTensorOp(VectorTransferOpInterface xferOp) {
+  if (isa<RankedTensorType>(xferOp.getShapedType())) {
+    if (isa<vector::TransferWriteOp>(xferOp)) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
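Note: the assert holds because `vector.transfer_write` on a tensor yields the updated tensor as a result, while the memref form yields nothing. A minimal IR sketch with hypothetical values:

```mlir
// Tensor destination: the op returns the updated tensor.
%t1 = vector.transfer_write %vec, %t0[%i] {in_bounds = [true]}
    : vector<4xf32>, tensor<?xf32>
// Memref destination: no results.
vector.transfer_write %vec, %m[%i] {in_bounds = [true]}
    : vector<4xf32>, memref<?xf32>
```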
/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
@@ -52,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
                       VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

+  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) const {
+    if (isTensorOp(xferOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          xferOp, "lowering tensor transfers is disabled");
+    }
+    return success();
+  }
+
  VectorTransferToSCFOptions options;
};

@@ -257,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

-/// Return true if this transfer op operates on a source tensor.
-template <typename OpTy>
-static bool isTensorOp(OpTy xferOp) {
-  if (isa<RankedTensorType>(xferOp.getShapedType())) {
-    if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
-      // TransferWriteOps on tensors have a result.
-      assert(xferOp->getNumResults() > 0);
-    }
-    return true;
-  }
-  return false;
-}
-
namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
@@ -987,6 +996,189 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  }
};

+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  return failure();
+}
+
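For reference, a sketch of the two supported producers in IR form (values hypothetical): `create_mask` dim sizes are its SSA operands, while `constant_mask` sizes are static, with scalable dims implicitly scaled by `vscale`:

```mlir
// create_mask: the dim sizes are the operands %a and %b.
%m0 = vector.create_mask %a, %b : vector<2x[4]xi1>
// constant_mask: the scalable dim size 4 denotes vscale x 4 lanes.
%m1 = vector.constant_mask [2, 4] : vector<2x[4]xi1>
```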
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///   : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///   : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(checkLowerTensors(writeOp, rewriter)))
+      return failure();
+
+    VectorType vectorType = writeOp.getVectorType();
+
+    // Note: By comparing the scalable dims to an ArrayRef of length two this
+    // implicitly checks that the rank is also two.
+    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
+    if (scalableFlags != ArrayRef<bool>{true, false}) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "expected vector of the form vector<[N]xMxty>");
+    }
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "non-identity permutations are unsupported (lower first)");
+    }
+
+    // Note: This pattern is only lowering the leading dimension (to a loop),
+    // so we only check if the leading dimension is in bounds. The in-bounds
+    // attribute for the trailing dimension will be propagated.
+    if (!writeOp.isDimInBounds(0)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "out-of-bounds dims are unsupported (use masking)");
+    }
+
+    Value vector = writeOp.getVector();
+    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp ||
+        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+    }
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+    if (failed(maskDims)) {
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "failed to resolve mask dims");
+    }
+
+    int64_t fixedDimSize = vectorType.getDimSize(1);
+    auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+    // Extract all slices from the source of the transpose.
+    auto transposeSource = transposeOp.getVector();
+    SmallVector<Value> transposeSourceSlices =
+        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+        });
+
+    // Loop bounds and step.
+    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto ub =
+        maskDims->empty()
+            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Generate a new mask for the slice.
+    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+    Value sliceMask = nullptr;
+    if (!maskDims->empty()) {
+      sliceMask = rewriter.create<vector::CreateMaskOp>(
+          loc, sliceType.clone(rewriter.getI1Type()),
+          ArrayRef<OpFoldResult>(*maskDims).drop_front());
+    }
+
+    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
+    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
+    auto result = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, initLoopArgs,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+          // Indices for the new transfer op.
+          SmallVector<Value, 8> xferIndices;
+          getXferIndices(b, writeOp, iv, xferIndices);
+
+          // Extract a transposed slice from the source vector.
+          SmallVector<Value> transposeElements =
+              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+                return b.create<vector::ExtractOp>(
+                    loc, transposeSourceSlices[idx], iv);
+              });
+          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+                                                           transposeElements);
+
+          // Create the transfer_write for the slice.
+          Value dest =
+              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+          auto newWriteOp = b.create<vector::TransferWriteOp>(
+              loc, sliceVec, dest, xferIndices,
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+          if (sliceMask)
+            newWriteOp.getMaskMutable().assign(sliceMask);
+
+          // Yield from the loop.
+          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+                                          ? ValueRange{}
+                                          : newWriteOp.getResult());
+        });
+
+    if (isTensorOp(writeOp))
+      rewriter.replaceOp(writeOp, result);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {
@@ -1100,9 +1292,8 @@ struct UnrollTransferReadConversion
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
-    if (isTensorOp(xferOp) && !options.lowerTensors)
-      return rewriter.notifyMatchFailure(
-          xferOp, "transfers operating on tensors are excluded");
+    if (failed(checkLowerTensors(xferOp, rewriter)))
+      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
@@ -1238,7 +1429,7 @@ struct UnrollTransferWriteConversion
    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

-    if (isTensorOp(xferOp) && !options.lowerTensors)
+    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
@@ -1503,7 +1694,10 @@ void mlir::populateVectorToSCFConversionPatterns(
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
-
+  if (options.lowerScalable) {
+    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+        patterns.getContext(), options);
+  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
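The new pattern is opt-in via `lowerScalable`. A minimal sketch of enabling it when populating patterns directly, using the field-assignment style the diff itself uses (the surrounding `MLIRContext`/pattern setup is assumed):

```cpp
// Enable the scalable transpose lowering alongside the other
// VectorToSCF patterns.
VectorTransferToSCFOptions options;
options.lowerScalable = true;
RewritePatternSet patterns(context);
mlir::populateVectorToSCFConversionPatterns(patterns, options);
```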
@@ -1522,13 +1716,15 @@ struct ConvertVectorToSCFPass
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
+    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
+    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
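The same option flows through the standalone pass, mirroring the `lowerTensors` wiring above. A sketch, assuming the usual `createConvertVectorToSCFPass` factory from VectorToSCF.h and an existing `OpPassManager` named `pm`:

```cpp
// Run the full conversion pass with the scalable lowering enabled.
VectorTransferToSCFOptions options;
options.lowerScalable = true;
pm.addPass(mlir::createConvertVectorToSCFPass(options));
```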