Skip to content

[mlir] Avoid common folder assuming all types are supported #68054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 3, 2023

Conversation

jpienaar
Copy link
Member

@jpienaar jpienaar commented Oct 2, 2023

Previously this would just assume all conversions are possible and this would crash. Use an in-tree testing case here.

Previously this would just assume all conversions are possible and this would crash. Use an in-tree testing case here.
@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Changes

Previously this would just assume all conversions are possible and this would crash. Use an in-tree testing case here.


Full diff: https://github.com/llvm/llvm-project/pull/68054.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/CommonFolders.h (+15-6)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6257e4a60188d57..7dabc781cd59526 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -93,8 +93,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     if (lhs.getType() != rhs.getType())
       return {};
 
-    auto lhsIt = lhs.value_begin<ElementValueT>();
-    auto rhsIt = rhs.value_begin<ElementValueT>();
+    auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
+    auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+    if (!maybeLhsIt || !maybeRhsIt)
+      return {};
+    auto lhsIt = *maybeLhsIt;
+    auto rhsIt = *maybeRhsIt;
     SmallVector<ElementValueT, 4> elementResults;
     elementResults.reserve(lhs.getNumElements());
     for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
@@ -227,7 +231,10 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
     // expanding the values.
     auto op = cast<ElementsAttr>(operands[0]);
 
-    auto opIt = op.value_begin<ElementValueT>();
+    auto maybeOpIt = op.try_value_begin<ElementValueT>();
+    if (!maybeOpIt)
+      return {};
+    auto opIt = *maybeOpIt;
     SmallVector<ElementValueT> elementResults;
     elementResults.reserve(op.getNumElements());
     for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
@@ -293,12 +300,14 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
       return {};
     return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
   }
-  if (isa<ElementsAttr>(operands[0])) {
+  if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
     // Operand is ElementsAttr-derived; perform an element-wise fold by
     // expanding the value.
-    auto op = cast<ElementsAttr>(operands[0]);
     bool castStatus = true;
-    auto opIt = op.value_begin<ElementValueT>();
+    auto maybeOpIt = op.try_value_begin<ElementValueT>();
+    if (!maybeOpIt)
+      return {};
+    auto opIt = *maybeOpIt;
     SmallVector<TargetElementValueT> elementResults;
     elementResults.reserve(op.getNumElements());
     for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 84096354e6afe33..8ee13fc9d1136a7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2669,3 +2669,20 @@ func.func @extsi_poison() -> i64 {
   %1 = arith.extsi %0 : i32 to i64
   return %1 : i64
 }
+
+// Just checks that this doesn't crashes.
+// CHECK-LABEL: @unsignedExtendConstantResource
+func.func @unsignedExtendConstantResource() -> tensor<i16> {
+  %c2 = arith.constant dense_resource<blob1> : tensor<i8>
+  %ext = arith.extui %c2 : tensor<i8> to tensor<i16>
+  return %ext : tensor<i16>
+}
+
+{-#
+  dialect_resources: {
+    builtin: {
+      // Note: This is just copied blob, the actual value isn't used or checked.
+      blob1: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % typo

@jpienaar jpienaar merged commit 0a53005 into llvm:main Oct 3, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants