15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
17
#include " mlir/Dialect/Tensor/IR/Tensor.h"
18
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
18
19
#include " mlir/Dialect/Vector/IR/VectorOps.h"
19
20
#include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20
21
#include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -577,7 +578,6 @@ class FlattenContiguousRowMajorTransferReadPattern
577
578
if (transferReadOp.getMask ())
578
579
return failure ();
579
580
580
- SmallVector<Value> collapsedIndices;
581
581
int64_t firstDimToCollapse = sourceType.getRank () - vectorType.getRank ();
582
582
583
583
// 1. Collapse the source memref
@@ -599,12 +599,14 @@ class FlattenContiguousRowMajorTransferReadPattern
599
599
// 2.2 New indices
600
600
// If all the collapsed indices are zero then no extra logic is needed.
601
601
// Otherwise, a new offset/index has to be computed.
602
+ SmallVector<Value> collapsedIndices;
602
603
if (failed (checkAndCollapseInnerZeroIndices (transferReadOp.getIndices (),
603
604
firstDimToCollapse,
604
605
collapsedIndices))) {
605
- // Copy all the leading indices
606
- collapsedIndices = transferReadOp.getIndices ();
607
- collapsedIndices.resize (firstDimToCollapse);
606
+ // Copy all the leading indices.
607
+ SmallVector<Value> indices = transferReadOp.getIndices ();
608
+ collapsedIndices.append (indices.begin (),
609
+ indices.begin () + firstDimToCollapse);
608
610
609
611
// Compute the remaining trailing index/offset required for reading from
610
612
// the collapsed memref:
@@ -621,24 +623,26 @@ class FlattenContiguousRowMajorTransferReadPattern
621
623
// memref<1x86xi32>, vector<2xi32>
622
624
// one would get the following offset:
623
625
// %offset = %arg0 * 43
624
- AffineExpr offsetExpr, idxExpr;
625
- bindSymbols (rewriter.getContext (), offsetExpr, idxExpr);
626
-
627
- int64_t outputRank = transferReadOp.getIndices ().size ();
628
- OpFoldResult offset =
626
+ OpFoldResult collapsedOffset =
629
627
rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ();
630
628
631
- for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
632
- int64_t dim = dyn_cast<ShapedType>(source.getType ()).getDimSize (i);
633
- offset = affine::makeComposedFoldedAffineApply (
634
- rewriter, loc, offsetExpr + dim * idxExpr,
635
- {offset, transferReadOp.getIndices ()[i]});
636
- }
637
- if (offset.is <Value>()) {
638
- collapsedIndices.push_back (offset.get <Value>());
629
+ auto sourceShape = sourceType.getShape ();
630
+ auto collapsedStrides = computeSuffixProduct (ArrayRef<int64_t >(
631
+ sourceShape.begin () + firstDimToCollapse, sourceShape.end ()));
632
+
633
+ // Compute the collapsed offset.
634
+ ArrayRef<Value> indicesToCollapse (indices.begin () + firstDimToCollapse,
635
+ indices.end ());
636
+ auto &&[collapsedExpr, collapsedVals] = computeLinearIndex (
637
+ collapsedOffset, collapsedStrides, indicesToCollapse);
638
+ collapsedOffset = affine::makeComposedFoldedAffineApply (
639
+ rewriter, loc, collapsedExpr, collapsedVals);
640
+
641
+ if (collapsedOffset.is <Value>()) {
642
+ collapsedIndices.push_back (collapsedOffset.get <Value>());
639
643
} else {
640
644
collapsedIndices.push_back (rewriter.create <arith::ConstantIndexOp>(
641
- loc, *getConstantIntValue (offset )));
645
+ loc, *getConstantIntValue (collapsedOffset )));
642
646
}
643
647
}
644
648
@@ -710,6 +714,7 @@ class FlattenContiguousRowMajorTransferWritePattern
710
714
firstContiguousInnerDim,
711
715
collapsedIndices)))
712
716
return failure ();
717
+
713
718
Value collapsedSource =
714
719
collapseInnerDims (rewriter, loc, source, firstContiguousInnerDim);
715
720
MemRefType collapsedSourceType =
0 commit comments