Skip to content

[mlir] Fix crash when folding tensor.dim(tensor.collapse()) on out-of-bound dim #119941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024

Conversation

joker-eph
Copy link
Collaborator

@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

#119866


Full diff: https://github.com/llvm/llvm-project/pull/119941.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+2-1)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9bb628781342ca..21f78cf96c70e9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2012,7 +2012,8 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
 
     // Only constant dimension values are supported.
     std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value())
+    if (!dim.has_value() ||
+        dim.value() >= collapseShapeOp.getResultType().getRank())
       return failure();
 
     // Skip static dims. These are folded to constant ops.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 613ec066337294..e8fc4ce834e18f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2344,6 +2344,20 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
 
 // -----
 
+// Can't fold when dim is out of bound.
+// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape(
+//       CHECK:   %[[DIM:.*]] = tensor.dim
+//       CHECK:   return %[[DIM]]
+func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+  %c5 = arith.constant 5 : index
+  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %0, %c5 : tensor<?x?xf32>
+  return %1 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @collapse_expand_fold_to_cast(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
 //       CHECK:   return %[[t]]

@joker-eph joker-eph merged commit 6a4750d into llvm:main Dec 16, 2024
11 checks passed
@joker-eph joker-eph deleted the fix_folding_crash branch December 16, 2024 15:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants