-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Fix crash in vector.insert
canonicalization
#97801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Fix crash in vector.insert
canonicalization
#97801
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/97801.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 53a6648de014c..7967aa1582fd8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2851,6 +2851,9 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
Attribute vectorDestCst;
if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
return failure();
+ auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
+ if (!denseDest)
+ return failure();
VectorType destTy = destVector.getType();
if (destTy.isScalable())
@@ -2860,9 +2863,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
if (destTy.getNumElements() > vectorSizeFoldThreshold &&
!destVector.hasOneUse())
return failure();
-
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
+
Value sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1a674d715ca61..9f16dfb5093d0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2729,3 +2729,14 @@ func.func @fold_vector_step_to_constant() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_regression(
+// CHECK: llvm.mlir.undef
+// CHECK: vector.insert
+func.func @vector_insert_regression(%arg0: i8) -> vector<4xi8> {
+ %0 = llvm.mlir.undef : vector<4xi8>
+ %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
+ return %1 : vector<4xi8>
+}
|
@llvm/pr-subscribers-mlir-vector Author: Matthias Springer (matthias-springer) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/97801.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 53a6648de014c..7967aa1582fd8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2851,6 +2851,9 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
Attribute vectorDestCst;
if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
return failure();
+ auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
+ if (!denseDest)
+ return failure();
VectorType destTy = destVector.getType();
if (destTy.isScalable())
@@ -2860,9 +2863,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
if (destTy.getNumElements() > vectorSizeFoldThreshold &&
!destVector.hasOneUse())
return failure();
-
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
+
Value sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1a674d715ca61..9f16dfb5093d0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2729,3 +2729,14 @@ func.func @fold_vector_step_to_constant() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_regression(
+// CHECK: llvm.mlir.undef
+// CHECK: vector.insert
+func.func @vector_insert_regression(%arg0: i8) -> vector<4xi8> {
+ %0 = llvm.mlir.undef : vector<4xi8>
+ %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
+ return %1 : vector<4xi8>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor nit but otherwise LGTM, cheers
The `InsertOpConstantFolder` assumed that whenever the destination can be folded to a constant attribute, that attribute must be a `DenseElementsAttr`. That is is not necessarily the case.
2a7b2e6
to
9db3421
Compare
The
InsertOpConstantFolder
assumed that whenever the destination can be folded to a constant attribute, that attribute must be aDenseElementsAttr
. That is is not necessarily the case.