@@ -505,25 +505,61 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
505
505
return rewriter.create <memref::CollapseShapeOp>(loc, input, reassociation);
506
506
}
507
507
508
- // / Checks that the indices corresponding to dimensions starting at
509
- // / `firstDimToCollapse` are constant 0, and writes to `outIndices`
510
- // / the truncated indices where `firstDimToCollapse` is now the innermost dim.
511
- // / TODO: Extract the logic that writes to outIndices so that this method
512
- // / simply checks one pre-condition.
513
- static LogicalResult
514
- checkAndCollapseInnerZeroIndices (ValueRange indices, int64_t firstDimToCollapse,
515
- SmallVector<Value> &outIndices) {
516
- int64_t rank = indices.size ();
517
- if (firstDimToCollapse >= rank)
518
- return failure ();
519
- for (int64_t i = firstDimToCollapse; i < rank; ++i) {
520
- std::optional<int64_t > cst = getConstantIntValue (indices[i]);
521
- if (!cst || cst.value () != 0 )
522
- return failure ();
508
+ // / Returns the new indices after collapsing the inner dimensions starting from
509
+ // / the `firstDimToCollapse` dimension.
510
+ static SmallVector<Value> getCollapsedIndices (RewriterBase &rewriter,
511
+ Location loc,
512
+ ArrayRef<int64_t > shape,
513
+ ValueRange indices,
514
+ int64_t firstDimToCollapse) {
515
+ assert (firstDimToCollapse < static_cast <int64_t >(indices.size ()));
516
+
517
+ // If all the collapsed indices are zero then no extra logic is needed.
518
+ // Otherwise, a new offset/index has to be computed.
519
+ SmallVector<Value> indicesAfterCollapsing (
520
+ indices.begin (), indices.begin () + firstDimToCollapse);
521
+ SmallVector<Value> indicesToCollapse (indices.begin () + firstDimToCollapse,
522
+ indices.end ());
523
+ if (llvm::all_of (indicesToCollapse, isZeroIndex)) {
524
+ indicesAfterCollapsing.push_back (indicesToCollapse[0 ]);
525
+ return indicesAfterCollapsing;
526
+ }
527
+
528
+ // Compute the remaining trailing index/offset required for reading from
529
+ // the collapsed memref:
530
+ //
531
+ // offset = 0
532
+ // for (i = firstDimToCollapse; i < indices.size(); ++i)
533
+ // offset += collapsedStrides[i - firstDimToCollapse] * indices[i]
534
+ //
535
+ // For this example:
536
+ // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
537
+ // memref<1x43x2xi32>, vector<1x2xi32>
538
+ // which would be collapsed to:
539
+ // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
540
+ // memref<1x86xi32>, vector<2xi32>
541
+ // one would get the following offset:
542
+ // %offset = %arg0 * 2
543
+ OpFoldResult collapsedOffset =
544
+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ();
545
+
546
+ auto collapsedStrides = computeSuffixProduct (
547
+ ArrayRef<int64_t >(shape.begin () + firstDimToCollapse, shape.end ()));
548
+
549
+ // Compute the collapsed offset.
550
+ auto &&[collapsedExpr, collapsedVals] =
551
+ computeLinearIndex (collapsedOffset, collapsedStrides, indicesToCollapse);
552
+ collapsedOffset = affine::makeComposedFoldedAffineApply (
553
+ rewriter, loc, collapsedExpr, collapsedVals);
554
+
555
+ if (collapsedOffset.is <Value>()) {
556
+ indicesAfterCollapsing.push_back (collapsedOffset.get <Value>());
557
+ } else {
558
+ indicesAfterCollapsing.push_back (rewriter.create <arith::ConstantIndexOp>(
559
+ loc, *getConstantIntValue (collapsedOffset)));
523
560
}
524
- outIndices = indices;
525
- outIndices.resize (firstDimToCollapse + 1 );
526
- return success ();
561
+
562
+ return indicesAfterCollapsing;
527
563
}
528
564
529
565
namespace {
@@ -594,54 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
594
630
AffineMap::get (collapsedRank, 0 , dimExprs, rewriter.getContext ());
595
631
596
632
// 2.2 New indices
597
- // If all the collapsed indices are zero then no extra logic is needed.
598
- // Otherwise, a new offset/index has to be computed.
599
- SmallVector<Value> collapsedIndices;
600
- if (failed (checkAndCollapseInnerZeroIndices (transferReadOp.getIndices (),
601
- firstDimToCollapse,
602
- collapsedIndices))) {
603
- // Copy all the leading indices.
604
- SmallVector<Value> indices = transferReadOp.getIndices ();
605
- collapsedIndices.append (indices.begin (),
606
- indices.begin () + firstDimToCollapse);
607
-
608
- // Compute the remaining trailing index/offset required for reading from
609
- // the collapsed memref:
610
- //
611
- // offset = 0
612
- // for (i = firstDimToCollapse; i < outputRank; ++i)
613
- // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
614
- //
615
- // For this example:
616
- // %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
617
- // memref<1x43x2xi32>, vector<1x2xi32>
618
- // which would be collapsed to:
619
- // %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
620
- // memref<1x86xi32>, vector<2xi32>
621
- // one would get the following offset:
622
- // %offset = %arg0 * 43
623
- OpFoldResult collapsedOffset =
624
- rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ();
625
-
626
- auto sourceShape = sourceType.getShape ();
627
- auto collapsedStrides = computeSuffixProduct (ArrayRef<int64_t >(
628
- sourceShape.begin () + firstDimToCollapse, sourceShape.end ()));
629
-
630
- // Compute the collapsed offset.
631
- ArrayRef<Value> indicesToCollapse (indices.begin () + firstDimToCollapse,
632
- indices.end ());
633
- auto &&[collapsedExpr, collapsedVals] = computeLinearIndex (
634
- collapsedOffset, collapsedStrides, indicesToCollapse);
635
- collapsedOffset = affine::makeComposedFoldedAffineApply (
636
- rewriter, loc, collapsedExpr, collapsedVals);
637
-
638
- if (collapsedOffset.is <Value>()) {
639
- collapsedIndices.push_back (collapsedOffset.get <Value>());
640
- } else {
641
- collapsedIndices.push_back (rewriter.create <arith::ConstantIndexOp>(
642
- loc, *getConstantIntValue (collapsedOffset)));
643
- }
644
- }
633
+ SmallVector<Value> collapsedIndices =
634
+ getCollapsedIndices (rewriter, loc, sourceType.getShape (),
635
+ transferReadOp.getIndices (), firstDimToCollapse);
645
636
646
637
// 3. Create new vector.transfer_read that reads from the collapsed memref
647
638
VectorType flatVectorType = VectorType::get ({vectorType.getNumElements ()},
@@ -697,31 +688,31 @@ class FlattenContiguousRowMajorTransferWritePattern
697
688
return failure ();
698
689
if (!vector::isContiguousSlice (sourceType, vectorType))
699
690
return failure ();
700
- int64_t firstContiguousInnerDim =
701
- sourceType.getRank () - vectorType.getRank ();
691
+ int64_t firstDimToCollapse = sourceType.getRank () - vectorType.getRank ();
702
692
// TODO: generalize this pattern, relax the requirements here.
703
693
if (transferWriteOp.hasOutOfBoundsDim ())
704
694
return failure ();
705
695
if (!transferWriteOp.getPermutationMap ().isMinorIdentity ())
706
696
return failure ();
707
697
if (transferWriteOp.getMask ())
708
698
return failure ();
709
- SmallVector<Value> collapsedIndices;
710
- if (failed (checkAndCollapseInnerZeroIndices (transferWriteOp.getIndices (),
711
- firstContiguousInnerDim,
712
- collapsedIndices)))
713
- return failure ();
699
+
700
+ SmallVector<Value> collapsedIndices =
701
+ getCollapsedIndices (rewriter, loc, sourceType.getShape (),
702
+ transferWriteOp.getIndices (), firstDimToCollapse);
714
703
715
704
Value collapsedSource =
716
- collapseInnerDims (rewriter, loc, source, firstContiguousInnerDim );
705
+ collapseInnerDims (rewriter, loc, source, firstDimToCollapse );
717
706
MemRefType collapsedSourceType =
718
707
cast<MemRefType>(collapsedSource.getType ());
719
708
int64_t collapsedRank = collapsedSourceType.getRank ();
720
- assert (collapsedRank == firstContiguousInnerDim + 1 );
709
+ assert (collapsedRank == firstDimToCollapse + 1 );
710
+
721
711
SmallVector<AffineExpr, 1 > dimExprs{
722
- getAffineDimExpr (firstContiguousInnerDim , rewriter.getContext ())};
712
+ getAffineDimExpr (firstDimToCollapse , rewriter.getContext ())};
723
713
auto collapsedMap =
724
714
AffineMap::get (collapsedRank, 0 , dimExprs, rewriter.getContext ());
715
+
725
716
VectorType flatVectorType = VectorType::get ({vectorType.getNumElements ()},
726
717
vectorType.getElementType ());
727
718
Value flatVector =
0 commit comments