 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///    : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (isTensorOp(writeOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "lowering tensor transfers is disabled");
+    }
+
+    auto vector = writeOp.getVector();
+    auto vectorType = vector.getType();
+    auto scalableFlags = vectorType.getScalableDims();
+    if (scalableFlags != ArrayRef<bool>{true, false}) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "expected vector of form vector<[*]x*xty>");
+    }
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "non-identity permutations are unsupported (lower first)");
+    }
+
+    if (!writeOp.isDimInBounds(0)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "out-of-bounds dims are unsupported (use masking)");
+    }
+
+    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp ||
+        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+    }
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+    if (failed(maskDims)) {
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "failed to resolve mask dims");
+    }
+
+    int64_t fixedDimSize = vectorType.getDimSize(1);
+    auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+    // Extract all slices from the source of the transpose.
+    auto transposeSource = transposeOp.getVector();
+    SmallVector<Value> transposeSourceSlices =
+        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+        });
+
+    // Loop bounds and step.
+    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto ub =
+        maskDims->empty()
+            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Generate a new mask for the slice.
+    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+    Value sliceMask = nullptr;
+    if (!maskDims->empty()) {
+      sliceMask = rewriter.create<vector::CreateMaskOp>(
+          loc, sliceType.clone(rewriter.getI1Type()),
+          ArrayRef<OpFoldResult>(*maskDims).drop_front());
+    }
+
+    ValueRange initLoopArgs =
+        isTensorOp(writeOp) ? writeOp.getSource() : ValueRange{};
+    auto result = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, initLoopArgs,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+          // Indices for the new transfer op.
+          SmallVector<Value, 8> xferIndices;
+          getXferIndices(b, writeOp, iv, xferIndices);
+
+          // Extract a transposed slice from the source vector.
+          SmallVector<Value> transposeElements =
+              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+                return b.create<vector::ExtractOp>(
+                    loc, transposeSourceSlices[idx], iv);
+              });
+          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+                                                           transposeElements);
+
+          // Create the transfer_write for the slice.
+          Value dest =
+              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+          auto newWriteOp = b.create<vector::TransferWriteOp>(
+              loc, sliceVec, dest, xferIndices,
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+          if (sliceMask)
+            newWriteOp.getMaskMutable().assign(sliceMask);
+
+          // Yield from the loop.
+          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+                                          ? ValueRange{}
+                                          : newWriteOp.getResult());
+        });
+
+    if (isTensorOp(writeOp))
+      rewriter.replaceOp(writeOp, result);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 } // namespace lowering_n_d
 
 namespace lowering_n_d_unrolled {
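Aside (not part of the patch): the scalable upper bound in the pattern above comes from `vector::makeVscaleConstantBuilder`, which is why this change adds the `VectorUtils.h` include. A minimal sketch of what the returned builder does, assuming a `rewriter` and `loc` are already in scope:

```cpp
// Sketch only: the callable returned by makeVscaleConstantBuilder maps an
// int64_t multiplier to an index Value computed as vscale * multiplier,
// e.g. the %c4_vscale bound in the AFTER example above.
auto createVscaleMultiple = vector::makeVscaleConstantBuilder(rewriter, loc);
Value upperBound = createVscaleMultiple(4); // arith.muli %vscale, %c4 : index
```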
@@ -1503,7 +1683,10 @@ void mlir::populateVectorToSCFConversionPatterns(
                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
         patterns.getContext(), options);
   }
-
+  if (options.lowerScalable) {
+    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+        patterns.getContext(), options);
+  }
   if (options.targetRank == 1) {
     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
@@ -1522,13 +1705,15 @@ struct ConvertVectorToSCFPass
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerTensors = options.lowerTensors;
+    this->lowerScalable = options.lowerScalable;
   }
 
   void runOnOperation() override {
     VectorTransferToSCFOptions options;
     options.unroll = fullUnroll;
     options.targetRank = targetRank;
     options.lowerTensors = lowerTensors;
+    options.lowerScalable = lowerScalable;
 
     // Lower permutation maps first.
     RewritePatternSet lowerTransferPatterns(&getContext());
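For reference, a minimal sketch of how a client might opt in to the new lowering through the options struct; `lowerVectorTransfers` and `root` are illustrative names, not part of this patch:

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: populate the VectorToSCF patterns with the new
// scalable transpose lowering enabled and apply them greedily.
static LogicalResult lowerVectorTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  VectorTransferToSCFOptions options;
  options.lowerScalable = true; // registers ScalableTransposeTransferWriteConversion
  populateVectorToSCFConversionPatterns(patterns, options);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

The pass plumbing in the last hunk exposes the same toggle as an option on `ConvertVectorToSCFPass`.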