Skip to content

Commit c2b9529

Browse files
authored
[mlir][vector] Fix n-d transfer write distribution (llvm#83215)
Currently, n-d transfer write distribution can be inconsistent with distribution of reductions if a value has multiple users, one of which is a transfer_write with a non-standard distribution map, and the other of which is a vector.reduction. We may want to consider removing the distribution-map functionality in the future for this reason.
1 parent 87c0260 commit c2b9529

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
443443
/// d1) and return vector<16x2x64>
444444
static VectorType getDistributedType(VectorType originalType, AffineMap map,
445445
int64_t warpSize) {
446-
if (map.getNumResults() != 1)
447-
return VectorType();
448446
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
449447
originalType.getShape().end());
450448
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
451449
unsigned position = map.getDimPosition(i);
452-
if (targetShape[position] % warpSize != 0)
453-
return VectorType();
450+
if (targetShape[position] % warpSize != 0) {
451+
if (warpSize % targetShape[position] != 0) {
452+
return VectorType();
453+
}
454+
warpSize /= targetShape[position];
455+
targetShape[position] = 1;
456+
continue;
457+
}
454458
targetShape[position] = targetShape[position] / warpSize;
459+
warpSize = 1;
460+
break;
461+
}
462+
if (warpSize != 1) {
463+
return VectorType();
455464
}
456465
VectorType targetType =
457466
VectorType::get(targetShape, originalType.getElementType());
@@ -526,7 +535,30 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
526535
// 4. Reindex the write using the distribution map.
527536
auto newWarpOp =
528537
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
538+
539+
// Delinearize the lane id based on the way threads are divided across the
540+
// vector. To get the number of threads per vector dimension, divide the
541+
// sequential size by the distributed size along each dim.
529542
rewriter.setInsertionPoint(newWriteOp);
543+
SmallVector<OpFoldResult> delinearizedIdSizes;
544+
for (auto [seqSize, distSize] :
545+
llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
546+
assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
547+
delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
548+
}
549+
SmallVector<Value> delinearized;
550+
if (map.getNumResults() > 1) {
551+
delinearized = rewriter
552+
.create<mlir::affine::AffineDelinearizeIndexOp>(
553+
newWarpOp.getLoc(), newWarpOp.getLaneid(),
554+
delinearizedIdSizes)
555+
.getResults();
556+
} else {
557+
// If there is only one map result, we can elide the delinearization
558+
// op and use the lane id directly.
559+
delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
560+
}
561+
530562
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
531563
Location loc = newWriteOp.getLoc();
532564
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
@@ -539,11 +571,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
539571
continue;
540572
unsigned indexPos = indexExpr.getPosition();
541573
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
574+
Value laneId = delinearized[vectorPos];
542575
auto scale =
543576
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
544577
indices[indexPos] = affine::makeComposedAffineApply(
545-
rewriter, loc, d0 + scale * d1,
546-
{indices[indexPos], newWarpOp.getLaneid()});
578+
rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
547579
}
548580
newWriteOp.getIndicesMutable().assign(indices);
549581

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,3 +1559,28 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
15591559
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
15601560
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
15611561
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
1562+
1563+
// -----
1564+
1565+
func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
1566+
%c0 = arith.constant 0 : index
1567+
vector.warp_execute_on_lane_0(%laneid)[32] -> () {
1568+
%0 = "some_def"() : () -> (vector<4x1024xf32>)
1569+
vector.transfer_write %0, %dest[%c0, %c0] : vector<4x1024xf32>, memref<4x1024xf32>
1570+
vector.yield
1571+
}
1572+
return
1573+
}
1574+
1575+
// CHECK-DIST-AND-PROP: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
1576+
1577+
// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_nd_write(
1578+
// CHECK-DIST-AND-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x128xf32>) {
1579+
// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def"
1580+
// CHECK-DIST-AND-PROP: vector.yield %[[V0]]
1581+
// CHECK-DIST-AND-PROP-SAME: vector<4x1024xf32>
1582+
// CHECK-DIST-AND-PROP: }
1583+
1584+
// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
1585+
// CHECK-DIST-AND-PROP: %[[INNER_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#1]
1586+
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,15 +630,13 @@ struct TestVectorDistribution
630630
});
631631
MLIRContext *ctx = &getContext();
632632
auto distributionFn = [](Value val) {
633-
// Create a map (d0, d1) -> (d1) to distribute along the inner
634-
// dimension. Once we support n-d distribution we can add more
635-
// complex cases.
633+
// Create an identity dim map of the same rank as the vector.
636634
VectorType vecType = dyn_cast<VectorType>(val.getType());
637635
int64_t vecRank = vecType ? vecType.getRank() : 0;
638636
OpBuilder builder(val.getContext());
639637
if (vecRank == 0)
640638
return AffineMap::get(val.getContext());
641-
return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
639+
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
642640
};
643641
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
644642
Value srcIdx, int64_t warpSz) {

0 commit comments

Comments
 (0)