Skip to content

[mlir][vector] Fix extractelement/insertelement folder crash on poison attr #71333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2023

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Nov 5, 2023

Types of incoming attributes weren't properly checked.

…n attr

Types of incoming attributes weren't properly checked.
@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

…n attr

Types of incoming attributes weren't properly checked.


Full diff: https://github.com/llvm/llvm-project/pull/71333.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+14-12)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+65-1)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 60416f550ee619d..69cbdcd3f536f98 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1188,9 +1188,6 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getPosition())
     return {};
 
-  Attribute src = adaptor.getVector();
-  Attribute pos = adaptor.getPosition();
-
   // Fold extractelement (splat X) -> X.
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
     return splat.getInput();
@@ -1200,13 +1197,16 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
     if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
       return broadcast.getSource();
 
+  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
+  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
   if (!pos || !src)
     return {};
 
-  auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<Attribute>();
+  auto srcElements = src.getValues<Attribute>();
 
-  auto attr = llvm::dyn_cast<IntegerAttr>(pos);
-  uint64_t posIdx = attr.getInt();
+  uint64_t posIdx = pos.getInt();
+  if (posIdx >= srcElements.size())
+    return {};
 
   return srcElements[posIdx];
 }
@@ -2511,18 +2511,20 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getPosition())
     return {};
 
-  Attribute src = adaptor.getSource();
-  Attribute dst = adaptor.getDest();
-  Attribute pos = adaptor.getPosition();
+  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
+  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
+  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
   if (!src || !dst || !pos)
     return {};
 
-  auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<Attribute>();
+  if (src.getType() != getDestVectorType().getElementType())
+    return {};
+
+  auto dstElements = dst.getValues<Attribute>();
 
   SmallVector<Attribute> results(dstElements);
 
-  auto attr = llvm::dyn_cast<IntegerAttr>(pos);
-  uint64_t posIdx = attr.getInt();
+  uint64_t posIdx = pos.getInt();
   if (posIdx >= results.size())
     return {};
   results[posIdx] = src;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f6bb42b1b249153..163fdd67b0cfd34 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2027,6 +2027,46 @@ func.func @insert_element_invalid_fold() -> vector<1xf32> {
   return %46 : vector<1xf32>
 }
 
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold1
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold1() -> vector<4xi32> {
+  %v = ub.poison : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold2
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold2() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = ub.poison : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold3
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold3() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = ub.poison : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
 // -----
 
 // CHECK-LABEL: func @extract_element_fold
@@ -2051,6 +2091,30 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
 
 // -----
 
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold1
+//       CHECK:   vector.extractelement
+func.func @extract_element_poison_fold1() -> i32 {
+  %v = ub.poison : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold2
+//       CHECK:   vector.extractelement
+func.func @extract_element_poison_fold2() -> i32 {
+  %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %i = ub.poison : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_one_element_vector_extract
 //  CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
 //       CHECK:   %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2436,4 +2500,4 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
-}
\ No newline at end of file
+}

@joker-eph
Copy link
Collaborator

Nit: please fix PR description/title wrapping

@Hardcode84 Hardcode84 changed the title [mlir][vector] Fix extractelement/insertelement folder crash on poiso… [mlir][vector] Fix extractelement/insertelement folder crash on poison attr Nov 5, 2023
@Hardcode84 Hardcode84 merged commit 0a22a80 into llvm:main Nov 6, 2023
@Hardcode84 Hardcode84 deleted the fix-vector-poison branch November 6, 2023 13:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants