@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
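For context, the body of this helper is not shown in the hunk. Below is a minimal sketch of the documented behavior, assuming MLIR's getConstantIntValue helper from mlir/Dialect/Utils/StaticValueUtils.h; the actual upstream implementation may differ in detail.

// Sketch only (not the upstream body): verify that every index from
// `firstDimToCollapse` onwards is a constant 0, then keep the leading
// indices plus the zero that now addresses the collapsed innermost dim.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>

using namespace mlir;

static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    // getConstantIntValue takes an OpFoldResult; a Value converts implicitly.
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || *cst != 0)
      return failure();
  }
  // Truncate: `firstDimToCollapse` becomes the innermost (zero) index.
  outIndices.assign(indices.begin(), indices.begin() + firstDimToCollapse + 1);
  return success();
}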
@@ -542,45 +544,100 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //     memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //     memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //   %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
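To see the change in action, here is a hedged driver sketch; it assumes the patterns in this file are exposed through vector::populateFlattenVectorTransferPatterns (declared in mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h), and the helper name is illustrative, not part of this commit.

// Illustrative only: greedily apply the flattening patterns to a module.
// After this patch, a read such as
//   %2 = vector.transfer_read %src[%c0, %arg0, %c0] (...) :
//       memref<1x43x2xi32>, vector<1x2xi32>
// is flattened even though %arg0 is a dynamic, non-zero-constant index.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult flattenVectorTransfers(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  vector::populateFlattenVectorTransferPatterns(patterns);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}

When the trailing indices are all constant zeros, checkAndCollapseInnerZeroIndices succeeds and the pre-existing path is taken unchanged; otherwise the composed affine apply folds the offset to a constant where possible, which is what the getConstantIntValue branch at the end of step 2.2 handles.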