@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
443
443
// / d1) and return vector<16x2x64>
444
444
static VectorType getDistributedType (VectorType originalType, AffineMap map,
445
445
int64_t warpSize) {
446
- if (map.getNumResults () != 1 )
447
- return VectorType ();
448
446
SmallVector<int64_t > targetShape (originalType.getShape ().begin (),
449
447
originalType.getShape ().end ());
450
448
for (unsigned i = 0 , e = map.getNumResults (); i < e; i++) {
451
449
unsigned position = map.getDimPosition (i);
452
- if (targetShape[position] % warpSize != 0 )
453
- return VectorType ();
450
+ if (targetShape[position] % warpSize != 0 ) {
451
+ if (warpSize % targetShape[position] != 0 ) {
452
+ return VectorType ();
453
+ }
454
+ warpSize /= targetShape[position];
455
+ targetShape[position] = 1 ;
456
+ continue ;
457
+ }
454
458
targetShape[position] = targetShape[position] / warpSize;
459
+ warpSize = 1 ;
460
+ break ;
461
+ }
462
+ if (warpSize != 1 ) {
463
+ return VectorType ();
455
464
}
456
465
VectorType targetType =
457
466
VectorType::get (targetShape, originalType.getElementType ());
@@ -526,7 +535,30 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
526
535
// 4. Reindex the write using the distribution map.
527
536
auto newWarpOp =
528
537
newWriteOp.getVector ().getDefiningOp <WarpExecuteOnLane0Op>();
538
+
539
+ // Delinearize the lane id based on the way threads are divided across the
540
+ // vector. To get the number of threads per vector dimension, divide the
541
+ // sequential size by the distributed size along each dim.
529
542
rewriter.setInsertionPoint (newWriteOp);
543
+ SmallVector<OpFoldResult> delinearizedIdSizes;
544
+ for (auto [seqSize, distSize] :
545
+ llvm::zip_equal (writtenVectorType.getShape (), targetType.getShape ())) {
546
+ assert (seqSize % distSize == 0 && " Invalid distributed vector shape" );
547
+ delinearizedIdSizes.push_back (rewriter.getIndexAttr (seqSize / distSize));
548
+ }
549
+ SmallVector<Value> delinearized;
550
+ if (map.getNumResults () > 1 ) {
551
+ delinearized = rewriter
552
+ .create <mlir::affine::AffineDelinearizeIndexOp>(
553
+ newWarpOp.getLoc (), newWarpOp.getLaneid (),
554
+ delinearizedIdSizes)
555
+ .getResults ();
556
+ } else {
557
+ // If there is only one map result, we can elide the delinearization
558
+ // op and use the lane id directly.
559
+ delinearized.append (targetType.getRank (), newWarpOp.getLaneid ());
560
+ }
561
+
530
562
AffineMap indexMap = map.compose (newWriteOp.getPermutationMap ());
531
563
Location loc = newWriteOp.getLoc ();
532
564
SmallVector<Value> indices (newWriteOp.getIndices ().begin (),
@@ -539,11 +571,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
539
571
continue ;
540
572
unsigned indexPos = indexExpr.getPosition ();
541
573
unsigned vectorPos = cast<AffineDimExpr>(std::get<1 >(it)).getPosition ();
574
+ Value laneId = delinearized[vectorPos];
542
575
auto scale =
543
576
rewriter.getAffineConstantExpr (targetType.getDimSize (vectorPos));
544
577
indices[indexPos] = affine::makeComposedAffineApply (
545
- rewriter, loc, d0 + scale * d1,
546
- {indices[indexPos], newWarpOp.getLaneid ()});
578
+ rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
547
579
}
548
580
newWriteOp.getIndicesMutable ().assign (indices);
549
581
0 commit comments