Commit 96c907d

Revert "[mlir][memref] memref.subview: Verify result strides" (#80116)
Reverts #79865. I think there is a bug in the stride computation in `SubViewOp::inferResultType`. (It was already there before this change.) Reverting this commit for now; the original pull request will be updated with a fix and more test cases.
1 parent f852503 commit 96c907d
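
For context on the suspected bug: a subview's result stride along each dimension is conventionally the source stride multiplied by the subview's stride operand, with any dynamic factor yielding a dynamic result. The sketch below illustrates that convention only — it is not the actual `SubViewOp::inferResultType` implementation, and the standalone `kDynamic` constant merely stands in for `ShapedType::kDynamic`.

```cpp
// Illustrative sketch (not MLIR's implementation): compose subview strides
// with source strides, propagating "dynamic" through every product.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for mlir::ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

std::vector<int64_t>
composeSubviewStrides(const std::vector<int64_t> &sourceStrides,
                      const std::vector<int64_t> &subviewStrides) {
  std::vector<int64_t> result;
  result.reserve(sourceStrides.size());
  for (std::size_t i = 0; i < sourceStrides.size(); ++i) {
    // A dynamic factor on either side makes the composed stride dynamic.
    if (sourceStrides[i] == kDynamic || subviewStrides[i] == kDynamic)
      result.push_back(kDynamic);
    else
      result.push_back(sourceStrides[i] * subviewStrides[i]);
  }
  return result;
}
```

For example, a row-major `memref<128x128xf32>` has strides `[128, 1]`; a subview with strides `[2, 1]` composes to `[256, 1]`. Several of the test updates below hinge on exactly this arithmetic.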

File tree: 5 files changed (+65 −90 lines)


mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 52 additions & 68 deletions
@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
 /// This accounts for cases where there are multiple unit-dims, but only a
 /// subset of those are dropped. For MemRefTypes these can be disambiguated
 /// using the strides. If a dimension is dropped the stride must be dropped too.
-static FailureOr<llvm::SmallBitVector>
+static std::optional<llvm::SmallBitVector>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
   llvm::SmallBitVector unusedDims(originalType.getRank());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
       failed(
           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
-    return failure();
+    return std::nullopt;
 
   // For memrefs, a dimension is truly dropped if its corresponding stride is
   // also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
         candidateStridesNumOccurences[originalStride]) {
       // This should never happen. Cant have a stride in the reduced rank type
       // that wasnt in the original one.
-      return failure();
+      return std::nullopt;
     }
   }
 
   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
       originalType.getRank())
-    return failure();
+    return std::nullopt;
   return unusedDims;
 }
 
 llvm::SmallBitVector SubViewOp::getDroppedDims() {
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();
-  FailureOr<llvm::SmallBitVector> unusedDims =
+  std::optional<llvm::SmallBitVector> unusedDims =
       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
-  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
+  assert(unusedDims && "unable to find unused dims of subview");
   return *unusedDims;
 }
 
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return getSource(); }
 
-/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
+/// Return true if t1 and t2 have equal offsets (both dynamic or of same
 /// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   int64_t t1Offset, t2Offset;
@@ -2755,41 +2755,56 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
 }
 
-/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
-  int64_t t1Offset, t2Offset;
-  SmallVector<int64_t> t1Strides, t2Strides;
-  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
-  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
-  if (failed(res1) || failed(res2))
-    return false;
-  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
-    if (s1 != s2)
-      return false;
-  return true;
+/// Checks if `original` Type type can be rank reduced to `reduced` type.
+/// This function is slight variant of `is subsequence` algorithm where
+/// not matching dimension must be 1.
+static SliceVerificationResult
+isRankReducedMemRefType(MemRefType originalType,
+                        MemRefType candidateRankReducedType,
+                        ArrayRef<OpFoldResult> sizes) {
+  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
+  if (partialRes != SliceVerificationResult::Success)
+    return partialRes;
+
+  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
+      originalType, candidateRankReducedType, sizes);
+
+  // Sizes cannot be matched in case empty vector is returned.
+  if (!optionalUnusedDimsMask)
+    return SliceVerificationResult::LayoutMismatch;
+
+  if (originalType.getMemorySpace() !=
+      candidateRankReducedType.getMemorySpace())
+    return SliceVerificationResult::MemSpaceMismatch;
+
+  // No amount of stride dropping can reconcile incompatible offsets.
+  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
+    return SliceVerificationResult::LayoutMismatch;
+
+  return SliceVerificationResult::Success;
 }
 
+template <typename OpTy>
 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
-                                            Operation *op, Type expectedType) {
+                                            OpTy op, Type expectedType) {
   auto memrefType = llvm::cast<ShapedType>(expectedType);
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op->emitError("expected result rank to be smaller or equal to ")
+    return op.emitError("expected result rank to be smaller or equal to ")
            << "the source rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op->emitError("expected result type to be ")
+    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op->emitError("expected result element type to be ")
+    return op.emitError("expected result element type to be ")
           << memrefType.getElementType();
   case SliceVerificationResult::MemSpaceMismatch:
-    return op->emitError("expected result and source memory spaces to match.");
+    return op.emitError("expected result and source memory spaces to match.");
   case SliceVerificationResult::LayoutMismatch:
-    return op->emitError("expected result type to be ")
+    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
   }
@@ -2811,46 +2826,13 @@ LogicalResult SubViewOp::verify() {
   if (!isStrided(baseType))
     return emitError("base type ") << baseType << " is not strided";
 
-  // Compute the expected result type, assuming that there are no rank
-  // reductions.
-  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
-
-  // Verify all properties of a shaped type: rank, element type and dimension
-  // sizes. This takes into account potential rank reductions.
-  auto shapedTypeVerification = isRankReducedType(
-      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
-  if (shapedTypeVerification != SliceVerificationResult::Success)
-    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
-
-  // Make sure that the memory space did not change.
-  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
-    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
-                                  *this, expectedType);
-
-  // Verify the offset of the layout map.
-  if (!haveCompatibleOffsets(expectedType, subViewType))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
-
-  // The only thing that's left to verify now are the strides. First, compute
-  // the unused dimensions due to rank reductions. We have to look at sizes and
-  // strides to decide which dimensions were dropped. This function also
-  // partially verifies strides in case of rank reductions.
-  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
-                                                   getMixedSizes());
-  if (failed(unusedDims))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
-
-  // Strides must match if there are no rank reductions.
-  // TODO: Verify strides when there are rank reductions. Strides are partially
-  // checked in `computeMemRefRankReductionMask`.
-  if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
+  // Verify result type against inferred type.
+  auto expectedType = SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
-  return success();
+  auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
+                                        subViewType, getMixedSizes());
+  return produceSubViewErrorMsg(result, *this, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -2900,9 +2882,11 @@ static MemRefType getCanonicalSubViewResultType(
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
   auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
       sourceType, mixedOffsets, mixedSizes, mixedStrides));
-  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
-      currentSourceType, currentResultType, mixedSizes);
-  if (failed(unusedDims))
+  std::optional<llvm::SmallBitVector> unusedDims =
+      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
+  // Return nullptr as failure mode.
+  if (!unusedDims)
     return nullptr;
 
   auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
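
The signature change running through this file swaps the helper's return type from `FailureOr<llvm::SmallBitVector>` back to `std::optional<llvm::SmallBitVector>`. A minimal sketch of the caller-side difference, assuming the MLIR headers of this era (where `FailureOr` lived in `mlir/Support/LogicalResult.h`):

```cpp
#include <optional>

#include "llvm/ADT/SmallBitVector.h"
#include "mlir/Support/LogicalResult.h"

// Reverted style: FailureOr is queried with failed()/succeeded().
void consumeFailureOr(mlir::FailureOr<llvm::SmallBitVector> mask) {
  if (mlir::failed(mask))
    return; // the failure state carries no payload
  llvm::SmallBitVector dims = *mask; // safe after the failed() check
  (void)dims;
}

// Restored style: std::optional is tested as a boolean.
void consumeOptional(std::optional<llvm::SmallBitVector> mask) {
  if (!mask)
    return; // std::nullopt signals failure
  llvm::SmallBitVector dims = *mask; // safe after the bool check
  (void)dims;
}
```

The revert likewise turns `produceSubViewErrorMsg` back into a template over `OpTy`, taking the op by value and calling `op.emitError(...)` instead of routing through an `Operation *` with `op->emitError(...)`.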

mlir/test/Dialect/GPU/decompose-memrefs.mlir

Lines changed: 3 additions & 3 deletions
@@ -119,7 +119,7 @@ func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
 // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
 // CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
 // CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
-// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
 func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -129,8 +129,8 @@ func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
-    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
-    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
     gpu.terminator
   }
   return
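
Note on the restored expectations: the `memref<?x?x?xf32>` source has identity-layout strides `[?, ?, 1]`, so composing with the subview strides `[2, 3, 4]` could in principle keep the innermost result stride static (1 × 4 = 4), which is what the reverted CHECK lines required (`strided<[?, ?, 4]>`). The restored test instead declares fully dynamic strides `[?, ?, ?]`, which the pre-#79865 verifier accepts because it does not compare strides against the inferred type.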

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 6 additions & 6 deletions
@@ -595,9 +595,9 @@ func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
 {
   %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
       : memref<1x1024xf32, 3>
-      to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+      to memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
   %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
-      : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+      : memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
       to memref<f32, strided<[], offset: ?>, 3>
   return %1 : memref<f32, strided<[], offset: ?>, 3>
 }
@@ -675,9 +675,9 @@ func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>,
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 // CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
-  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
+  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[64, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
   return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
 }
@@ -686,9 +686,9 @@ func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %ar
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 // CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
-  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
+  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
   return
 }
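Arithmetic check on these types: the `memref<128x128xf32>` source has strides `[128, 1]` and the subview uses strides `[2, 1]`, so the composed result strides are `[256, 1]` (128 × 2 = 256), the value #79865 required. The revert restores the original `strided<[64, 1]>` annotations, which the looser verifier does not reject. The same pattern appears in @subview_of_subview, where plain composition of `[1024, 1]` with unit subview strides yields `[1024, 1]`, yet the restored type reads `[1024, 2]`; the discrepancy is presumably tied to the `inferResultType` stride bug mentioned in the commit message.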

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 0 additions & 9 deletions
@@ -1073,12 +1073,3 @@ func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
   memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
   return
 }
-
-// -----
-
-func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
-  // expected-error @below{{expected result type to be 'memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
-  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
-    : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
-  return
-}
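
The deleted negative test exercised the stride check directly: the `memref<7x22x333x4444xi32>` source has identity strides `[22·333·4444, 333·4444, 4444, 1] = [32556744, 1479852, 4444, 1]`, and the subview strides `[1, 2, 1, 1]` double the second entry to `2959704`, which is exactly the layout named in the expected error. The declared `memref<7x11x333x4444xi32>` result carries plain identity strides instead, hence the layout mismatch. With stride verification reverted, the op verifies again and the test is removed.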

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

Lines changed: 4 additions & 4 deletions
@@ -88,10 +88,10 @@ module {
     // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1.
     %xys = memref.alloc() : memref<20xi32>
     %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
-    %x0 = memref.subview %xy[%i0][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x1 = memref.subview %xy[%i1][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x2 = memref.subview %xy[%i2][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %y0 = memref.subview %xy[%i3][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x1 = memref.subview %xy[%i1][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x2 = memref.subview %xy[%i2][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %y0 = memref.subview %xy[%i3][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
     %y1s = memref.alloc() : memref<7xi32>
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
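Here the revert restores dynamic stride operands. Under #79865's verifier, a dynamic stride operand `%i4` (presumably an index value defined earlier in the test) makes the inferred result stride dynamic, which conflicts with the static `strided<[4], offset: ?>` annotation, so the test had been changed to the literal `4`. With the stride comparison reverted, the original dynamic operands are accepted again.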
