Skip to content

Commit 91f62f0

Browse files
committed
[mlir][vector] Fix distribution of scf.for with value coming from above
When a value used in the forOp is defined outside the region but within the parent warpOp, we need to return and distribute the value in order to pass it to the new operations created within the loop. Also simplify the lambda interface. Differential Revision: https://reviews.llvm.org/D137146
1 parent 7fdf356 commit 91f62f0

File tree

4 files changed

+159
-42
lines changed

4 files changed

+159
-42
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
4040
const WarpExecuteOnLane0LoweringOptions &options,
4141
PatternBenefit benefit = 1);
4242

43-
using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
43+
using DistributionMapFn = std::function<AffineMap(Value)>;
4444

4545
/// Distribute transfer_write ops based on the affine map returned by
4646
/// `distributionMapFn`.
@@ -67,9 +67,12 @@ void populateDistributeTransferWriteOpPatterns(
6767
/// region.
6868
void moveScalarUniformCode(WarpExecuteOnLane0Op op);
6969

70-
/// Collect patterns to propagate warp distribution.
70+
/// Collect patterns to propagate warp distribution. `distributionMapFn` is used
71+
/// to decide how a value should be distributed when this cannot be inferred
72+
/// from its uses.
7173
void populatePropagateWarpVectorDistributionPatterns(
72-
RewritePatternSet &pattern, PatternBenefit benefit = 1);
74+
RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
75+
PatternBenefit benefit = 1);
7376

7477
/// Lambda signature to compute a reduction of a distributed value for the given
7578
/// reduction kind and size.

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1414
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
1515
#include "mlir/IR/AffineExpr.h"
16+
#include "mlir/Transforms/RegionUtils.h"
1617
#include "mlir/Transforms/SideEffectUtils.h"
1718
#include "llvm/ADT/SetVector.h"
1819
#include <utility>
@@ -421,6 +422,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
421422
return newWriteOp;
422423
}
423424

425+
/// Return the distributed vector type based on the original type and the
426+
/// distribution map. The map is expected to have a number of dimensions equal
427+
/// to the original type rank and should be a projection where the results are
428+
/// distributed dimensions. The number of results should be equal to the number
429+
/// of warp sizes which is currently limited to 1.
430+
/// Example: for a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
431+
/// and a warp size of 16, the second dimension (associated with d1) would be
432+
/// distributed, returning vector<16x2x64>
433+
static VectorType getDistributedType(VectorType originalType, AffineMap map,
434+
int64_t warpSize) {
435+
if (map.getNumResults() != 1)
436+
return VectorType();
437+
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
438+
originalType.getShape().end());
439+
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
440+
unsigned position = map.getDimPosition(i);
441+
if (targetShape[position] % warpSize != 0)
442+
return VectorType();
443+
targetShape[position] = targetShape[position] / warpSize;
444+
}
445+
VectorType targetType =
446+
VectorType::get(targetShape, originalType.getElementType());
447+
return targetType;
448+
}
449+
424450
/// Distribute transfer_write ops based on the affine map returned by
425451
/// `distributionMapFn`.
426452
/// Example:
@@ -456,29 +482,19 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
456482
if (writtenVectorType.getRank() == 0)
457483
return failure();
458484

459-
// 2. Compute the distribution map.
460-
AffineMap map = distributionMapFn(writeOp);
461-
if (map.getNumResults() != 1)
462-
return writeOp->emitError("multi-dim distribution not implemented yet");
463-
464-
// 3. Compute the targetType using the distribution map.
465-
SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
466-
writtenVectorType.getShape().end());
467-
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
468-
unsigned position = map.getDimPosition(i);
469-
if (targetShape[position] % warpOp.getWarpSize() != 0)
470-
return failure();
471-
targetShape[position] = targetShape[position] / warpOp.getWarpSize();
472-
}
485+
// 2. Compute the distributed type.
486+
AffineMap map = distributionMapFn(writeOp.getVector());
473487
VectorType targetType =
474-
VectorType::get(targetShape, writtenVectorType.getElementType());
488+
getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
489+
if (!targetType)
490+
return failure();
475491

476-
// 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
492+
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
477493
// the rest.
478494
vector::TransferWriteOp newWriteOp =
479495
cloneWriteOp(rewriter, warpOp, writeOp, targetType);
480496

481-
// 5. Reindex the write using the distribution map.
497+
// 4. Reindex the write using the distribution map.
482498
auto newWarpOp =
483499
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
484500
rewriter.setInsertionPoint(newWriteOp);
@@ -494,7 +510,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
494510
continue;
495511
unsigned indexPos = indexExpr.getPosition();
496512
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
497-
auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
513+
auto scale =
514+
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
498515
indices[indexPos] =
499516
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
500517
{indices[indexPos], newWarpOp.getLaneid()});
@@ -956,6 +973,10 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
956973
/// }
957974
/// ```
958975
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
976+
977+
WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
978+
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
979+
distributionMapFn(std::move(fn)) {}
959980
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
960981
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
961982
PatternRewriter &rewriter) const override {
@@ -966,35 +987,78 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
966987
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
967988
if (!forOp)
968989
return failure();
990+
// Collect Values that come from the warp op but are outside the forOp.
991+
// Those values need to be returned by the original warpOp and passed to the
992+
// new op.
993+
llvm::SmallSetVector<Value, 32> escapingValues;
994+
SmallVector<Type> inputTypes;
995+
SmallVector<Type> distTypes;
996+
mlir::visitUsedValuesDefinedAbove(
997+
forOp.getBodyRegion(), [&](OpOperand *operand) {
998+
Operation *parent = operand->get().getParentRegion()->getParentOp();
999+
if (warpOp->isAncestor(parent)) {
1000+
if (!escapingValues.insert(operand->get()))
1001+
return;
1002+
Type distType = operand->get().getType();
1003+
if (auto vecType = distType.cast<VectorType>()) {
1004+
AffineMap map = distributionMapFn(operand->get());
1005+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1006+
}
1007+
inputTypes.push_back(operand->get().getType());
1008+
distTypes.push_back(distType);
1009+
}
1010+
});
1011+
1012+
SmallVector<size_t> newRetIndices;
1013+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1014+
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1015+
newRetIndices);
1016+
yield = cast<vector::YieldOp>(
1017+
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1018+
9691019
SmallVector<Value> newOperands;
9701020
SmallVector<unsigned> resultIdx;
9711021
// Collect all the outputs coming from the forOp.
9721022
for (OpOperand &yieldOperand : yield->getOpOperands()) {
9731023
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
9741024
continue;
9751025
auto forResult = yieldOperand.get().cast<OpResult>();
976-
newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
1026+
newOperands.push_back(
1027+
newWarpOp.getResult(yieldOperand.getOperandNumber()));
9771028
yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
9781029
resultIdx.push_back(yieldOperand.getOperandNumber());
9791030
}
1031+
9801032
OpBuilder::InsertionGuard g(rewriter);
981-
rewriter.setInsertionPointAfter(warpOp);
1033+
rewriter.setInsertionPointAfter(newWarpOp);
1034+
9821035
// Create a new for op outside the region with a WarpExecuteOnLane0Op region
9831036
// inside.
9841037
auto newForOp = rewriter.create<scf::ForOp>(
9851038
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
9861039
forOp.getStep(), newOperands);
9871040
rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1041+
1042+
SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1043+
newForOp.getRegionIterArgs().end());
1044+
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1045+
forOp.getResultTypes().end());
1046+
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1047+
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1048+
warpInput.push_back(newWarpOp.getResult(retIdx));
1049+
argIndexMapping[escapingValues[i]] = warpInputType.size();
1050+
warpInputType.push_back(inputTypes[i]);
1051+
}
9881052
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
989-
warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
990-
warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
991-
forOp.getResultTypes());
1053+
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1054+
newWarpOp.getWarpSize(), warpInput, warpInputType);
9921055

9931056
SmallVector<Value> argMapping;
9941057
argMapping.push_back(newForOp.getInductionVar());
9951058
for (Value args : innerWarp.getBody()->getArguments()) {
9961059
argMapping.push_back(args);
9971060
}
1061+
argMapping.resize(forOp.getBody()->getNumArguments());
9981062
SmallVector<Value> yieldOperands;
9991063
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
10001064
yieldOperands.push_back(operand);
@@ -1008,12 +1072,23 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
10081072
rewriter.eraseOp(forOp);
10091073
// Replace the warpOp result coming from the original ForOp.
10101074
for (const auto &res : llvm::enumerate(resultIdx)) {
1011-
warpOp.getResult(res.value())
1075+
newWarpOp.getResult(res.value())
10121076
.replaceAllUsesWith(newForOp.getResult(res.index()));
1013-
newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
1077+
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
10141078
}
1079+
newForOp.walk([&](Operation *op) {
1080+
for (OpOperand &operand : op->getOpOperands()) {
1081+
auto it = argIndexMapping.find(operand.get());
1082+
if (it == argIndexMapping.end())
1083+
continue;
1084+
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1085+
}
1086+
});
10151087
return success();
10161088
}
1089+
1090+
private:
1091+
DistributionMapFn distributionMapFn;
10171092
};
10181093

10191094
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1119,11 +1194,14 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
11191194
}
11201195

11211196
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1122-
RewritePatternSet &patterns, PatternBenefit benefit) {
1197+
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1198+
PatternBenefit benefit) {
11231199
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
11241200
WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
1125-
WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
1126-
patterns.getContext(), benefit);
1201+
WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
1202+
benefit);
1203+
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1204+
benefit);
11271205
}
11281206

11291207
void mlir::vector::populateDistributeReduction(

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,40 @@ func.func @warp_scf_for(%arg0: index) {
349349

350350
// -----
351351

352+
// CHECK-PROP-LABEL: func @warp_scf_for_use_from_above(
353+
// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
354+
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
355+
// CHECK-PROP: %[[USE:.*]] = "some_def_above"() : () -> vector<128xf32>
356+
// CHECK-PROP: vector.yield %[[INI1]], %[[USE]] : vector<128xf32>, vector<128xf32>
357+
// CHECK-PROP: }
358+
// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]#0) -> (vector<4xf32>) {
359+
// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]], %[[INI]]#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
360+
// CHECK-PROP: ^bb0(%[[ARG0:.*]]: vector<128xf32>, %[[ARG1:.*]]: vector<128xf32>):
361+
// CHECK-PROP: %[[ACC:.*]] = "some_def"(%[[ARG0]], %[[ARG1]]) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32>
362+
// CHECK-PROP: vector.yield %[[ACC]] : vector<128xf32>
363+
// CHECK-PROP: }
364+
// CHECK-PROP: scf.yield %[[W]] : vector<4xf32>
365+
// CHECK-PROP: }
366+
// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
367+
func.func @warp_scf_for_use_from_above(%arg0: index) {
368+
%c128 = arith.constant 128 : index
369+
%c1 = arith.constant 1 : index
370+
%c0 = arith.constant 0 : index
371+
%0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
372+
%ini = "some_def"() : () -> (vector<128xf32>)
373+
%use_from_above = "some_def_above"() : () -> (vector<128xf32>)
374+
%3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
375+
%acc = "some_def"(%arg4, %use_from_above) : (vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
376+
scf.yield %acc : vector<128xf32>
377+
}
378+
vector.yield %3 : vector<128xf32>
379+
}
380+
"some_use"(%0) : (vector<4xf32>) -> ()
381+
return
382+
}
383+
384+
// -----
385+
352386
// CHECK-PROP-LABEL: func @warp_scf_for_swap(
353387
// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
354388
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -746,24 +746,26 @@ struct TestVectorDistribution
746746
}
747747
});
748748
MLIRContext *ctx = &getContext();
749+
auto distributionFn = [](Value val) {
750+
// Create a map (d0, d1) -> (d1) to distribute along the inner
751+
// dimension. Once we support n-d distribution we can add more
752+
// complex cases.
753+
VectorType vecType = val.getType().dyn_cast<VectorType>();
754+
int64_t vecRank = vecType ? vecType.getRank() : 0;
755+
OpBuilder builder(val.getContext());
756+
if (vecRank == 0)
757+
return AffineMap::get(val.getContext());
758+
return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
759+
};
749760
if (distributeTransferWriteOps) {
750-
auto distributionFn = [](vector::TransferWriteOp writeOp) {
751-
// Create a map (d0, d1) -> (d1) to distribute along the inner
752-
// dimension. Once we support n-d distribution we can add more
753-
// complex cases.
754-
int64_t vecRank = writeOp.getVectorType().getRank();
755-
OpBuilder builder(writeOp.getContext());
756-
auto map =
757-
AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
758-
return map;
759-
};
760761
RewritePatternSet patterns(ctx);
761762
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
762763
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
763764
}
764765
if (propagateDistribution) {
765766
RewritePatternSet patterns(ctx);
766-
vector::populatePropagateWarpVectorDistributionPatterns(patterns);
767+
vector::populatePropagateWarpVectorDistributionPatterns(patterns,
768+
distributionFn);
767769
vector::populateDistributeReduction(patterns, warpReduction);
768770
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
769771
}

0 commit comments

Comments
 (0)