Skip to content

[mlir][arith] Overflow flags propagation in arith canonicalizations. #91646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2024

Conversation

Hardcode84
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/91646.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+18-17)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+7)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+153)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 02d05780a7ac1..6f8e0b868a179 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
-// TODO: Canonicalizations currently doesn't take into account integer overflow
-// flags and always reset them to default (wraparound) which is safe but can
-// inhibit later optimizations. Individual patterns must be reviewed for
-// better handling of overflow flags.
+// Merge overflow flags from 2 ops, selecting most conservative combination.
+def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
+
+// Default oveflow flag (wraparound).
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,7 +45,7 @@ def AddIAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
            (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
-        (Arith_SubIOp $x, $y, DefOverflow),
+        (Arith_SubIOp $x, $y, (MergeOverflow $ovf1, $ovf2)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
            (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
            $y, $ovf2),
-        (Arith_SubIOp $y, $x, DefOverflow),
+        (Arith_SubIOp $y, $x, (MergeOverflow $ovf1, $ovf2)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -153,12 +153,13 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
+            (MergeOverflow $ovf1, $ovf2))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a1568d0ebba3a..d86dfe91771b5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -64,6 +64,13 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerOverflowFlagsAttr
+mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
+                   IntegerOverflowFlagsAttr val2) {
+  return IntegerOverflowFlagsAttr::get(val1.getContext(),
+                                       val1.getValue() & val2.getValue());
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f7ce2123a93c6..1b0999fa8a798 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddAddOvf1
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddAddOvf1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddAddOvf2
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddAddOvf2(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleAddSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddSub0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddSub0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleAddSub1
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddSub1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddSub1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubAdd0
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubAdd0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubAdd1
 //       CHECK:   %[[cres:.+]] = arith.constant -25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -881,6 +941,18 @@ func.func @tripleSubAdd1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubAdd1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubAdd1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @subSub0
 //       CHECK:   %[[c0:.+]] = arith.constant 0 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[c0]], %arg1 : index
@@ -891,6 +963,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
   return %sub2 : index
 }
 
+// CHECK-LABEL: @subSub0Ovf
+//       CHECK:   %[[c0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
+  %sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
+  %sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
+  return %sub2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -903,6 +985,19 @@ func.func @tripleSubSub0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
+
 // CHECK-LABEL: @tripleSubSub1
 //       CHECK:   %[[cres:.+]] = arith.constant -25 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -915,6 +1010,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub2
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -927,6 +1034,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub2Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub2Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub3
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
@@ -939,6 +1058,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub3Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub3Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @subAdd1
 //  CHECK-NEXT:   return %arg0
 func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
@@ -1018,6 +1149,17 @@ func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
   return %add : i32
 }
 
+// CHECK-LABEL: @addiMuliToSubiRhsI32Ovf
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw, nuw> : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+  %add = arith.addi %arg0, %neg overflow<nsw, nuw> : i32
+  return %add : i32
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsIndex
 //  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
@@ -1051,6 +1193,17 @@ func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
   return %add : i32
 }
 
+// CHECK-LABEL: @addiMuliToSubiLhsI32Ovf
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw, nuw> : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+  %add = arith.addi %neg, %arg0 overflow<nsw, nuw> : i32
+  return %add : i32
+}
+
 // CHECK-LABEL: @addiMuliToSubiLhsIndex
 //  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving this.

I saw some possible issue when running the multiplication one through alive. Can you double check if there is a real problem or if I just made a mistake when replicating the test case?

// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw, nuw> : i32
// CHECK: return %[[SUB]]
func.func @addiMuliToSubiLhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When running this through alive I get a mismatch:
https://alive2.llvm.org/ce/z/8NhXoT
I did enter things manually so apologies if I made a mistake. It seems like the multiplication of 0 with -1 does not result in poison (despite the nuw flag) while sub 0, 1 does result in a unsigned overflow. The result is a bit of a surprise for me but maybe others can shed some light on it?

Could make to run this tests systematically through alive? I did it manually but it could be fairly quick to lower the test input and output to llvm ir and the copy past these overflow flag tests individually into alive. I have more trust in a tool than in my reviewing skills here. What are your takes on this @kuhar @Hardcode84?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, some script which can run entire arith canonicalization suite through alive2 will be extremely useful. I can think about it and switch this PR to draft fro now.

Copy link
Contributor Author

@Hardcode84 Hardcode84 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I wrote this abomination https://gist.github.com/Hardcode84/13a8c5a45aa0e3849e1ab5d30a56749e

  • -split-input-file is not handled properly and I had to manually remove // ----- markers
  • Had to remove tensor mentions from the file as it cannot be translated
  • Had to remove i0 tests as they hard-crash -mlir-to-llvmir
  • Had to manually cleanup duplicated source_filename and metadata from result.

And all of this was in vain as alive2 timeouts on such big input https://alive2.llvm.org/ce/z/bjvhLR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess they website has limited compute time allocated :(.

I don't think the script is in vain though. I would start by running it on the new tests you added. Worst case test case by test case to avoid the time out.

Copy link
Contributor Author

@Hardcode84 Hardcode84 May 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is better script https://gist.github.com/Hardcode84/bbf08e0923d4b9004aa831d0a1d3ac4f

Usage python verify_canon.py D:/projs/llvm/llvm-project/mlir/test/Dialect/Arith/canonicalize.mlir tripleAddAddOvf1 tripleAddAddOvf2 tripleAddSub0Ovf tripleAddSub1Ovf tripleSubAdd0Ovf tripleSubAdd1Ovf subSub0Ovf tripleSubSub0Ovf tripleSubSub1Ovf tripleSubSub2Ovf tripleSubSub3Ovf addiMuliToSubiRhsI32Ovf addiMuliToSubiLhsI32Ovf > out.txt

output https://alive2.llvm.org/ce/z/KhQs4J

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Looks like the patterns that convert add <-> sub seem to be tricky, probably because addition uses an unsigned overflow to implement subtraction? So maybe dropping the nuw flag helps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove patterns that failed from this PR and will create separate PRs for them and for the script.

@Hardcode84 Hardcode84 marked this pull request as draft May 10, 2024 09:29
@Hardcode84 Hardcode84 marked this pull request as ready for review May 11, 2024 20:16
@Hardcode84
Copy link
Contributor Author

Removed canons where MergeOverflow doesn't work for now. https://alive2.llvm.org/ce/z/gQAsdn

Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request May 11, 2024
This script takes IR before and after canonicalization, translates it into llvm IR and converts to format suitablle for Alive2 https://alive2.llvm.org/ce/

This is primarily for arith canonicalizations verification, but technically it can be adapted for any dialect translatable to llvm.

Usage `python verify_canon.py canonicalize.mlir func1 func2 ...`

Example output: https://alive2.llvm.org/ce/z/KhQs4J

Initial discussion: llvm#91646 (review)
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM modulo the open comments.

@Hardcode84 Hardcode84 merged commit 345f57d into llvm:main May 13, 2024
@Hardcode84 Hardcode84 deleted the arith-overflow-canon branch May 13, 2024 11:47
Hardcode84 added a commit that referenced this pull request May 13, 2024
…91867)

This script takes IR before and after canonicalization, translates it
into llvm IR and converts it to format suitable for Alive2
https://alive2.llvm.org/ce/

This is primarily for arith canonicalizations verification, but
technically it can be adapted for any dialect translatable to llvm.

Usage `python verify_canon.py canonicalize.mlir -f func1 func2 ...`

Example output: https://alive2.llvm.org/ce/z/KhQs4J

Initial discussion:
#91646 (review)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants