13
13
#include " mlir/Dialect/Vector/IR/VectorOps.h"
14
14
#include " mlir/Dialect/Vector/Transforms/VectorDistribution.h"
15
15
#include " mlir/IR/AffineExpr.h"
16
+ #include " mlir/Transforms/RegionUtils.h"
16
17
#include " mlir/Transforms/SideEffectUtils.h"
17
18
#include " llvm/ADT/SetVector.h"
18
19
#include < utility>
@@ -421,6 +422,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
421
422
return newWriteOp;
422
423
}
423
424
425
+ // / Return the distributed vector type based on the original type and the
426
+ // / distribution map. The map is expected to have a dimension equal to the
427
+ // / original type rank and should be a projection where the results are the
428
+ // / distributed dimensions. The number of results should be equal to the number
429
+ // / of warp sizes which is currently limited to 1.
430
+ // / Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
431
+ // / and a warp size of 16 would distribute the second dimension (associated to
432
+ // / d1) and return vector<16x2x64>
433
+ static VectorType getDistributedType (VectorType originalType, AffineMap map,
434
+ int64_t warpSize) {
435
+ if (map.getNumResults () != 1 )
436
+ return VectorType ();
437
+ SmallVector<int64_t > targetShape (originalType.getShape ().begin (),
438
+ originalType.getShape ().end ());
439
+ for (unsigned i = 0 , e = map.getNumResults (); i < e; i++) {
440
+ unsigned position = map.getDimPosition (i);
441
+ if (targetShape[position] % warpSize != 0 )
442
+ return VectorType ();
443
+ targetShape[position] = targetShape[position] / warpSize;
444
+ }
445
+ VectorType targetType =
446
+ VectorType::get (targetShape, originalType.getElementType ());
447
+ return targetType;
448
+ }
449
+
424
450
// / Distribute transfer_write ops based on the affine map returned by
425
451
// / `distributionMapFn`.
426
452
// / Example:
@@ -456,29 +482,19 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
456
482
if (writtenVectorType.getRank () == 0 )
457
483
return failure ();
458
484
459
- // 2. Compute the distribution map.
460
- AffineMap map = distributionMapFn (writeOp);
461
- if (map.getNumResults () != 1 )
462
- return writeOp->emitError (" multi-dim distribution not implemented yet" );
463
-
464
- // 3. Compute the targetType using the distribution map.
465
- SmallVector<int64_t > targetShape (writtenVectorType.getShape ().begin (),
466
- writtenVectorType.getShape ().end ());
467
- for (unsigned i = 0 , e = map.getNumResults (); i < e; i++) {
468
- unsigned position = map.getDimPosition (i);
469
- if (targetShape[position] % warpOp.getWarpSize () != 0 )
470
- return failure ();
471
- targetShape[position] = targetShape[position] / warpOp.getWarpSize ();
472
- }
485
+ // 2. Compute the distributed type.
486
+ AffineMap map = distributionMapFn (writeOp.getVector ());
473
487
VectorType targetType =
474
- VectorType::get (targetShape, writtenVectorType.getElementType ());
488
+ getDistributedType (writtenVectorType, map, warpOp.getWarpSize ());
489
+ if (!targetType)
490
+ return failure ();
475
491
476
- // 4 . clone the write into a new WarpExecuteOnLane0Op to separate it from
492
+ // 3 . clone the write into a new WarpExecuteOnLane0Op to separate it from
477
493
// the rest.
478
494
vector::TransferWriteOp newWriteOp =
479
495
cloneWriteOp (rewriter, warpOp, writeOp, targetType);
480
496
481
- // 5 . Reindex the write using the distribution map.
497
+ // 4 . Reindex the write using the distribution map.
482
498
auto newWarpOp =
483
499
newWriteOp.getVector ().getDefiningOp <WarpExecuteOnLane0Op>();
484
500
rewriter.setInsertionPoint (newWriteOp);
@@ -494,7 +510,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
494
510
continue ;
495
511
unsigned indexPos = indexExpr.getPosition ();
496
512
unsigned vectorPos = std::get<1 >(it).cast <AffineDimExpr>().getPosition ();
497
- auto scale = rewriter.getAffineConstantExpr (targetShape[vectorPos]);
513
+ auto scale =
514
+ rewriter.getAffineConstantExpr (targetType.getDimSize (vectorPos));
498
515
indices[indexPos] =
499
516
makeComposedAffineApply (rewriter, loc, d0 + scale * d1,
500
517
{indices[indexPos], newWarpOp.getLaneid ()});
@@ -956,6 +973,10 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
956
973
// / }
957
974
// / ```
958
975
struct WarpOpScfForOp : public OpRewritePattern <WarpExecuteOnLane0Op> {
976
+
977
+ WarpOpScfForOp (MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1 )
978
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
979
+ distributionMapFn (std::move(fn)) {}
959
980
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
960
981
LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
961
982
PatternRewriter &rewriter) const override {
@@ -966,35 +987,78 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
966
987
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
967
988
if (!forOp)
968
989
return failure ();
990
+ // Collect Values that come from the warp op but are outside the forOp.
991
+ // Those Value needs to be returned by the original warpOp and passed to the
992
+ // new op.
993
+ llvm::SmallSetVector<Value, 32 > escapingValues;
994
+ SmallVector<Type> inputTypes;
995
+ SmallVector<Type> distTypes;
996
+ mlir::visitUsedValuesDefinedAbove (
997
+ forOp.getBodyRegion (), [&](OpOperand *operand) {
998
+ Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
999
+ if (warpOp->isAncestor (parent)) {
1000
+ if (!escapingValues.insert (operand->get ()))
1001
+ return ;
1002
+ Type distType = operand->get ().getType ();
1003
+ if (auto vecType = distType.cast <VectorType>()) {
1004
+ AffineMap map = distributionMapFn (operand->get ());
1005
+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1006
+ }
1007
+ inputTypes.push_back (operand->get ().getType ());
1008
+ distTypes.push_back (distType);
1009
+ }
1010
+ });
1011
+
1012
+ SmallVector<size_t > newRetIndices;
1013
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1014
+ rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1015
+ newRetIndices);
1016
+ yield = cast<vector::YieldOp>(
1017
+ newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1018
+
969
1019
SmallVector<Value> newOperands;
970
1020
SmallVector<unsigned > resultIdx;
971
1021
// Collect all the outputs coming from the forOp.
972
1022
for (OpOperand &yieldOperand : yield->getOpOperands ()) {
973
1023
if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
974
1024
continue ;
975
1025
auto forResult = yieldOperand.get ().cast <OpResult>();
976
- newOperands.push_back (warpOp.getResult (yieldOperand.getOperandNumber ()));
1026
+ newOperands.push_back (
1027
+ newWarpOp.getResult (yieldOperand.getOperandNumber ()));
977
1028
yieldOperand.set (forOp.getIterOperands ()[forResult.getResultNumber ()]);
978
1029
resultIdx.push_back (yieldOperand.getOperandNumber ());
979
1030
}
1031
+
980
1032
OpBuilder::InsertionGuard g (rewriter);
981
- rewriter.setInsertionPointAfter (warpOp);
1033
+ rewriter.setInsertionPointAfter (newWarpOp);
1034
+
982
1035
// Create a new for op outside the region with a WarpExecuteOnLane0Op region
983
1036
// inside.
984
1037
auto newForOp = rewriter.create <scf::ForOp>(
985
1038
forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
986
1039
forOp.getStep (), newOperands);
987
1040
rewriter.setInsertionPoint (newForOp.getBody (), newForOp.getBody ()->begin ());
1041
+
1042
+ SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1043
+ newForOp.getRegionIterArgs ().end ());
1044
+ SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1045
+ forOp.getResultTypes ().end ());
1046
+ llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1047
+ for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1048
+ warpInput.push_back (newWarpOp.getResult (retIdx));
1049
+ argIndexMapping[escapingValues[i]] = warpInputType.size ();
1050
+ warpInputType.push_back (inputTypes[i]);
1051
+ }
988
1052
auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
989
- warpOp.getLoc (), newForOp.getResultTypes (), warpOp.getLaneid (),
990
- warpOp.getWarpSize (), newForOp.getRegionIterArgs (),
991
- forOp.getResultTypes ());
1053
+ newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1054
+ newWarpOp.getWarpSize (), warpInput, warpInputType);
992
1055
993
1056
SmallVector<Value> argMapping;
994
1057
argMapping.push_back (newForOp.getInductionVar ());
995
1058
for (Value args : innerWarp.getBody ()->getArguments ()) {
996
1059
argMapping.push_back (args);
997
1060
}
1061
+ argMapping.resize (forOp.getBody ()->getNumArguments ());
998
1062
SmallVector<Value> yieldOperands;
999
1063
for (Value operand : forOp.getBody ()->getTerminator ()->getOperands ())
1000
1064
yieldOperands.push_back (operand);
@@ -1008,12 +1072,23 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1008
1072
rewriter.eraseOp (forOp);
1009
1073
// Replace the warpOp result coming from the original ForOp.
1010
1074
for (const auto &res : llvm::enumerate (resultIdx)) {
1011
- warpOp .getResult (res.value ())
1075
+ newWarpOp .getResult (res.value ())
1012
1076
.replaceAllUsesWith (newForOp.getResult (res.index ()));
1013
- newForOp->setOperand (res.index () + 3 , warpOp .getResult (res.value ()));
1077
+ newForOp->setOperand (res.index () + 3 , newWarpOp .getResult (res.value ()));
1014
1078
}
1079
+ newForOp.walk ([&](Operation *op) {
1080
+ for (OpOperand &operand : op->getOpOperands ()) {
1081
+ auto it = argIndexMapping.find (operand.get ());
1082
+ if (it == argIndexMapping.end ())
1083
+ continue ;
1084
+ operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1085
+ }
1086
+ });
1015
1087
return success ();
1016
1088
}
1089
+
1090
+ private:
1091
+ DistributionMapFn distributionMapFn;
1017
1092
};
1018
1093
1019
1094
// / A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1119,11 +1194,14 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
1119
1194
}
1120
1195
1121
1196
void mlir::vector::populatePropagateWarpVectorDistributionPatterns (
1122
- RewritePatternSet &patterns, PatternBenefit benefit) {
1197
+ RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1198
+ PatternBenefit benefit) {
1123
1199
patterns.add <WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1124
1200
WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
1125
- WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
1126
- patterns.getContext (), benefit);
1201
+ WarpOpForwardOperand, WarpOpConstant>(patterns.getContext (),
1202
+ benefit);
1203
+ patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
1204
+ benefit);
1127
1205
}
1128
1206
1129
1207
void mlir::vector::populateDistributeReduction (
0 commit comments