
Commit 27a713f

[mlir][vector] Add scalable lowering for transfer_write(transpose) (#101353)
This specifically handles the case of a transpose from a vector type like
`vector<8x[4]xf32>` to `vector<[4]x8xf32>`. Such transposes occur fairly
frequently when scalably vectorizing `linalg.generic`s. There is no direct
lowering for these (as types like `vector<[4]x8xf32>` cannot be represented
in LLVM-IR). However, if the only use of the transpose is a write, then it is
possible to lower the `transfer_write(transpose)` as a VLA loop.

Example:

```mlir
%transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
```

Becomes:

```mlir
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
%1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
%2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
%3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
scf.for %idx = %c0 to %c4_vscale step %c1 {
  %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
  %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
  %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
  %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
  %slice_i = affine.apply #map(%idx)[%i]
  %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
  vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
}
```
1 parent 23c72e9 commit 27a713f

5 files changed (+346 -21 lines)

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 1 deletion
@@ -1300,7 +1300,9 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
     Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
            "Target vector rank to which transfer ops should be lowered">,
     Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
-           "Lower transfer ops that operate on tensors">
+           "Lower transfer ops that operate on tensors">,
+    Option<"lowerScalable", "lower-scalable", "bool", /*default=*/"false",
+           "Add scalable vector specific lowerings (that introduce loops)">
   ];
 }

mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h

Lines changed: 8 additions & 0 deletions
@@ -69,6 +69,14 @@ struct VectorTransferToSCFOptions {
     unroll = u;
     return *this;
   }
+  /// Enable scalable vector specific lowerings (which introduce loops). These
+  /// work alongside fullUnroll (which unrolls until the first scalable
+  /// dimension).
+  bool lowerScalable = false;
+  VectorTransferToSCFOptions enableLowerScalable(bool enable = true) {
+    lowerScalable = enable;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + func.
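For context, a minimal sketch (not part of this commit) of how a downstream pipeline could opt in to the new lowering through this options struct. The wrapper function `lowerScalableTransfers` is hypothetical; `VectorTransferToSCFOptions`, `populateVectorToSCFConversionPatterns`, and the greedy pattern driver are existing MLIR APIs:

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: applies the VectorToSCF patterns with the new
// scalable-specific lowering (and tensor lowering) enabled.
static LogicalResult lowerScalableTransfers(Operation *op) {
  VectorTransferToSCFOptions options;
  options.enableLowerScalable(); // Option added by this patch.
  options.enableLowerTensors();  // Needed for transfer ops on tensors.

  RewritePatternSet patterns(op->getContext());
  populateVectorToSCFConversionPatterns(patterns, options);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```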

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 214 additions & 18 deletions
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"
@@ -44,6 +45,18 @@ namespace {
 /// Attribute name used for labeling transfer ops during progressive lowering.
 static const char kPassLabel[] = "__vector_to_scf_lowering__";
 
+/// Return true if this transfer op operates on a source tensor.
+static bool isTensorOp(VectorTransferOpInterface xferOp) {
+  if (isa<RankedTensorType>(xferOp.getShapedType())) {
+    if (isa<vector::TransferWriteOp>(xferOp)) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
 /// Patterns that inherit from this struct have access to
 /// VectorTransferToSCFOptions.
 template <typename OpTy>
@@ -52,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
                               VectorTransferToSCFOptions opt)
       : OpRewritePattern<OpTy>(context), options(opt) {}
 
+  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) const {
+    if (isTensorOp(xferOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          xferOp, "lowering tensor transfers is disabled");
+    }
+    return success();
+  }
+
   VectorTransferToSCFOptions options;
 };
 
@@ -257,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
 }
 
-/// Return true if this transfer op operates on a source tensor.
-template <typename OpTy>
-static bool isTensorOp(OpTy xferOp) {
-  if (isa<RankedTensorType>(xferOp.getShapedType())) {
-    if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
-      // TransferWriteOps on tensors have a result.
-      assert(xferOp->getNumResults() > 0);
-    }
-    return true;
-  }
-  return false;
-}
-
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -987,6 +996,189 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+/// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///    : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(checkLowerTensors(writeOp, rewriter)))
+      return failure();
+
+    VectorType vectorType = writeOp.getVectorType();
+
+    // Note: By comparing the scalable dims to an ArrayRef of length two this
+    // implicitly checks the rank (is also two).
+    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
+    if (scalableFlags != ArrayRef<bool>{true, false}) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "expected vector of the form vector<[N]xMxty>");
+    }
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "non-identity permutations are unsupported (lower first)");
+    }
+
+    // Note: This pattern is only lowering the leading dimension (to a loop),
+    // so we only check if the leading dimension is in bounds. The in-bounds
+    // attribute for the trailing dimension will be propagated.
+    if (!writeOp.isDimInBounds(0)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "out-of-bounds dims are unsupported (use masking)");
+    }
+
+    Value vector = writeOp.getVector();
+    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp ||
+        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+    }
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+    if (failed(maskDims)) {
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "failed to resolve mask dims");
+    }
+
+    int64_t fixedDimSize = vectorType.getDimSize(1);
+    auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+    // Extract all slices from the source of the transpose.
+    auto transposeSource = transposeOp.getVector();
+    SmallVector<Value> transposeSourceSlices =
+        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+        });
+
+    // Loop bounds and step.
+    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto ub =
+        maskDims->empty()
+            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Generate a new mask for the slice.
+    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+    Value sliceMask = nullptr;
+    if (!maskDims->empty()) {
+      sliceMask = rewriter.create<vector::CreateMaskOp>(
+          loc, sliceType.clone(rewriter.getI1Type()),
+          ArrayRef<OpFoldResult>(*maskDims).drop_front());
+    }
+
+    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
+    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
+    auto result = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, initLoopArgs,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+          // Indices for the new transfer op.
+          SmallVector<Value, 8> xferIndices;
+          getXferIndices(b, writeOp, iv, xferIndices);
+
+          // Extract a transposed slice from the source vector.
+          SmallVector<Value> transposeElements =
+              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+                return b.create<vector::ExtractOp>(
+                    loc, transposeSourceSlices[idx], iv);
+              });
+          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+                                                           transposeElements);
+
+          // Create the transfer_write for the slice.
+          Value dest =
+              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+          auto newWriteOp = b.create<vector::TransferWriteOp>(
+              loc, sliceVec, dest, xferIndices,
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+          if (sliceMask)
+            newWriteOp.getMaskMutable().assign(sliceMask);
+
+          // Yield from the loop.
+          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+                                          ? ValueRange{}
+                                          : newWriteOp.getResult());
+        });
+
+    if (isTensorOp(writeOp))
+      rewriter.replaceOp(writeOp, result);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 } // namespace lowering_n_d
 
 namespace lowering_n_d_unrolled {
@@ -1100,9 +1292,8 @@ struct UnrollTransferReadConversion
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
-    if (isTensorOp(xferOp) && !options.lowerTensors)
-      return rewriter.notifyMatchFailure(
-          xferOp, "transfers operating on tensors are excluded");
+    if (failed(checkLowerTensors(xferOp, rewriter)))
+      return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
         xferOp.getShapedType().getElementType())
@@ -1238,7 +1429,7 @@ struct UnrollTransferWriteConversion
     if (inputVectorTy.getRank() <= options.targetRank)
       return failure();
 
-    if (isTensorOp(xferOp) && !options.lowerTensors)
+    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (inputVectorTy.getElementType() !=
@@ -1503,7 +1694,10 @@ void mlir::populateVectorToSCFConversionPatterns(
                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
         patterns.getContext(), options);
   }
-
+  if (options.lowerScalable) {
+    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+        patterns.getContext(), options);
+  }
   if (options.targetRank == 1) {
     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
@@ -1522,13 +1716,15 @@ struct ConvertVectorToSCFPass
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerTensors = options.lowerTensors;
+    this->lowerScalable = options.lowerScalable;
   }
 
   void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
+    options.lowerScalable = lowerScalable;
 
    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
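For reference, a minimal sketch (not part of this commit) of scheduling the updated pass from C++ with the new option enabled. It relies on the existing `createConvertVectorToSCFPass(options)` factory and standard `PassManager` APIs; the `runVectorToSCF` wrapper is hypothetical:

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Hypothetical helper: runs convert-vector-to-scf (with lower-scalable
// enabled) on every function in the module.
static LogicalResult runVectorToSCF(ModuleOp module) {
  VectorTransferToSCFOptions options;
  options.enableLowerScalable();

  PassManager pm(module.getContext());
  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass(options));
  return pm.run(module);
}
```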

mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read_2d(
 // CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
   return %t : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scalable_transpose_store
+// CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
+// CHECK: %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
+// CHECK: scf.yield %[[WRITE_SLICE]]
+// CHECK: return %[[RESULT]]
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %A: tensor<?x?xf32>, %base1: index, %base2: index) -> tensor<?x?xf32> {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %result = vector.transfer_write %transpose, %A[%base1, %base2] {in_bounds = [true, true]} : vector<[4]x4xf32>, tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
