Skip to content

Commit efac666

Browse files
[fixup] Don't try to collapse non-leftmost dynamic dimension
Even though it's possible in principle, the affected patterns need strides to be determined statically.
1 parent 1590fe6 commit efac666

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,15 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
582582

583583
namespace {
584584

585+
/// Helper functon to return the index of the last dynamic dimension in `shape`.
586+
int64_t lastDynIndex(ArrayRef<int64_t> shape) {
587+
return static_cast<int64_t>(
588+
std::distance(
589+
std::find(shape.rbegin(), shape.rend(), ShapedType::kDynamic),
590+
shape.rend()) -
591+
1);
592+
}
593+
585594
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
586595
/// memref.collapse_shape on the source so that the resulting
587596
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -631,8 +640,9 @@ class FlattenContiguousRowMajorTransferReadPattern
631640
return failure();
632641

633642
// Determinine the first memref dimension to collapse
634-
int64_t firstDimToCollapse =
635-
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
643+
int64_t firstDimToCollapse = std::max(
644+
lastDynIndex(sourceType.getShape()),
645+
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
636646

637647
// 1. Collapse the source memref
638648
Value collapsedSource =
@@ -725,8 +735,9 @@ class FlattenContiguousRowMajorTransferWritePattern
725735
return failure();
726736

727737
// Determinine the first memref dimension to collapse
728-
int64_t firstDimToCollapse =
729-
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
738+
int64_t firstDimToCollapse = std::max(
739+
lastDynIndex(sourceType.getShape()),
740+
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
730741

731742
// 1. Collapse the source memref
732743
Value collapsedSource =

0 commit comments

Comments
 (0)