@@ -582,6 +582,15 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
582
582
583
583
namespace {
584
584
585
+ // / Helper functon to return the index of the last dynamic dimension in `shape`.
586
+ int64_t lastDynIndex (ArrayRef<int64_t > shape) {
587
+ return static_cast <int64_t >(
588
+ std::distance (
589
+ std::find (shape.rbegin (), shape.rend (), ShapedType::kDynamic ),
590
+ shape.rend ()) -
591
+ 1 );
592
+ }
593
+
585
594
// / Rewrites contiguous row-major vector.transfer_read ops by inserting
586
595
// / memref.collapse_shape on the source so that the resulting
587
596
// / vector.transfer_read has a 1D source. Requires the source shape to be
@@ -631,8 +640,9 @@ class FlattenContiguousRowMajorTransferReadPattern
631
640
return failure ();
632
641
633
642
// Determinine the first memref dimension to collapse
634
- int64_t firstDimToCollapse =
635
- sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ();
643
+ int64_t firstDimToCollapse = std::max (
644
+ lastDynIndex (sourceType.getShape ()),
645
+ sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ());
636
646
637
647
// 1. Collapse the source memref
638
648
Value collapsedSource =
@@ -725,8 +735,9 @@ class FlattenContiguousRowMajorTransferWritePattern
725
735
return failure ();
726
736
727
737
// Determinine the first memref dimension to collapse
728
- int64_t firstDimToCollapse =
729
- sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ();
738
+ int64_t firstDimToCollapse = std::max (
739
+ lastDynIndex (sourceType.getShape ()),
740
+ sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ());
730
741
731
742
// 1. Collapse the source memref
732
743
Value collapsedSource =
0 commit comments