Skip to content

[mlir] Fix a zero stride canonicalizer crash #74200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 6, 2023
Merged

[mlir] Fix a zero stride canonicalizer crash #74200

merged 8 commits into from
Dec 6, 2023

Conversation

rikhuijzer
Copy link
Member

This PR fixes #73383 and is another shot at the refactoring proposed in #72885.

@llvmbot
Copy link
Member

llvmbot commented Dec 2, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

Changes

This PR fixes #73383 and is another shot at the refactoring proposed in #72885.


Full diff: https://github.com/llvm/llvm-project/pull/74200.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+27-3)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-11)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-10)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 502ab93ddbfa7..a1853438ccf7f 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,12 +139,36 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Helper function to check whether the passed in `sizes` or `values` are
+/// valid. This can be used to re-check whether dimensions are still valid
+/// after constant folding the dynamic dimensions.
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid. This
+/// can be used to re-check whether dimensions are still valid after constant
+/// folding the dynamic dimensions.
+bool hasValidStrides(SmallVector<int64_t> strides);
+
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
+/// non-negative and non-zero constant values are folded respectively.
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative = false);
+                                   bool onlyNonNegative = false,
+                                   bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `OffsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Returns "success" when any of the elements in `strides` is a constant
+/// value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dce96cca016ff..b2d52e400e52d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2581,17 +2581,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (!hasValidSizesOffsets(staticOffsets))
+    return {};
+  if (!hasValidSizesOffsets(staticSizes))
+    return {};
+  if (!hasValidStrides(staticStrides))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8970ea1c73b40..94b7b734f88fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1446,13 +1446,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
-        return failure();
-    }
+    if (!hasValidSizesOffsets(newShape))
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -2548,9 +2543,9 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedStrides)))
+    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+        failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form.
@@ -2691,6 +2686,8 @@ struct InsertSliceOpSourceCastInserter final
         newSrcShape[i] = *constInt;
       }
     }
+    if (!hasValidSizesOffsets(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index c7a3d8fc8eb28..0c8a88da789e2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
+  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
+bool hasValidStrides(SmallVector<int64_t> strides) {
+  return llvm::none_of(strides, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value == 0;
+  });
+}
+
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative) {
+                                   bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
       // Note: All ofrs have index type.
       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
         continue;
+      if (onlyNonZero && *getConstantIntValue(attr) == 0)
+        continue;
       ofr = attr;
       valuesChanged = true;
     }
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
   return success(valuesChanged);
 }
 
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+                              /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+                              /*onlyNonZero=*/true);
+}
+
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff8..d3406c630f6dd 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
 
 // -----
 
+// CHECK-LABEL: func @no_fold_subview_zero_stride
+//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
+//  CHECK:        return %[[SUBVIEW]]
+func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+  return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 //  CHECK:   %[[cst:.+]] = memref.cast %arg
 //  CHECK:   memref.store %[[cst]]

Copy link
Member

@Lewuathe Lewuathe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add the test case (or minimal similar example) with --inline option so that we can confirm the original issue is resolved.

#73383

@rikhuijzer
Copy link
Member Author

Is it possible to add the test case (or minimal similar example) with --inline option so that we can confirm the original issue is resolved.

#73383

Thanks for the review and fixing the typos that I've made! I've added a test in a new MemRef/inlining.mlir file, which is in line with the other files:

$ rg -l '\-inline' mlir/test
mlir/test/Dialect/MemRef/inlining.mlir
mlir/test/Dialect/Vector/inlining.mlir
mlir/test/Dialect/UB/inlining.mlir
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
mlir/test/Dialect/Affine/inlining.mlir
mlir/test/Dialect/LLVMIR/inline-byval-huge.mlir
mlir/test/Dialect/Bufferization/inlining.mlir
mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
mlir/test/Dialect/LLVMIR/inlining.mlir
mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
mlir/test/Dialect/Tosa/inlining.mlir
mlir/test/Dialect/Linalg/inlining.mlir
mlir/test/lib/Transforms/TestInlining.cpp
mlir/test/Transforms/inlining.mlir
mlir/test/Transforms/inlining-repeated-use.mlir
mlir/test/Transforms/inlining-recursive.mlir
mlir/test/Transforms/inlining-dce.mlir
mlir/test/Transforms/test-inlining.mlir

(Also test additions can always be reverted of course so it shouldn't be too bad.) Also I double checked that #73383 doesn't crash on this PR.

Copy link
Member

@Lewuathe Lewuathe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update. It LGTM.

}
// CHECK-NEXT: arith.constant 0 : index
// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview
// CHECK-NEXT: return %[[SUBVIEW]] : memref<1xf32, strided<[?], offset: 1>>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a canonicalized crash, we don't need to run the inliner.

Copy link
Member Author

@rikhuijzer rikhuijzer Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this file again in the force push to ed0e4c9. 👍

@joker-eph
Copy link
Collaborator

Is it possible to add the test case (or minimal similar example) with --inline option so that we can confirm the original issue is resolved.

#73383

It is common that the person filing the issue does not reduce the IR or the pipeline that causes the bug.
In particular this particular reporter (are you associated with them @Lewuathe ?) seems like automated fuzzer. Here the only thing that the inliner does is running the canonicalizer.
I updated the issue instead.

@Lewuathe
Copy link
Member

Lewuathe commented Dec 6, 2023

@joker-eph Thank you. That gets clearer to me!

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@rikhuijzer rikhuijzer merged commit 68f0bc6 into llvm:main Dec 6, 2023
@rikhuijzer rikhuijzer deleted the rh/negative-extract-slice-canon branch December 6, 2023 06:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Canonicalizer crashed with assertion failure.
4 participants