Skip to content

[mlir][intrange] Use nsw,nuw flags in inference #92642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2024

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented May 18, 2024

This patch includes the "no signed wrap" and "no unsigned wrap" flags,
which can be used to annotate some Ops in the arith dialect and also
in LLVMIR, in the integer range inference.

The general approach is to use saturating arithmetic operations to infer
bounds which are assumed to not wrap and use overflowing arithmetic
operations in the normal case. If overflow is detected in the normal case,
special handling makes sure that we don't underestimate the result range.

Stacked on #92641

@llvmbot
Copy link
Member

llvmbot commented May 18, 2024

@llvm/pr-subscribers-mlir-index

@llvm/pr-subscribers-mlir-arith

Author: Felix Schneider (ubfx)

Changes

This patch includes the "no signed wrap" and "no unsigned wrap" flags,
which can be used to annotate some Ops in the arith dialect and also
in LLVMIR, in the integer range inference.

The general approach is to use saturating arithmetic operations to infer
bounds which are assumed to not wrap and use overflowing arithmetic
operations in the normal case. If overflow is detected in the normal case,
special handling makes sure that we don't underestimate the result range.

Stacked on #92641


Patch is 24.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92642.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+21-4)
  • (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+18-4)
  • (modified) mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp (+18-4)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+58-40)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+134-1)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+2-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+14-5)
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
+enum class OverflowFlags : uint32_t {
+  None = 0,
+  Nsw = 1,
+  Nuw = 2,
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+
+/// Function that performs inference on an array of `ConstantIntRanges` while
+/// taking special overflow behavior into account.
+using InferRangeWithOvfFlagsFn =
+    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
 ConstantIntRanges truncRange(const ConstantIntRanges &range,
                              unsigned destWidth);
 
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
 
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
 
 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
 
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..6024f37359ad5 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
+intrange::OverflowFlags
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+  intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+    retFlags |= intrange::OverflowFlags::Nsw;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+    retFlags |= intrange::OverflowFlags::Nuw;
+  return retFlags;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..ffa74f0c0faae 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // we take the 64-bit result).
 //===----------------------------------------------------------------------===//
 
+// Some arithmetic inference functions allow specifying special overflow / wrap
+// behavior. We do not require this for the IndexOps and use this helper to call
+// the inference function without any `OverflowFlags`.
+std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+    return inferWithOvfFn(argRanges, OverflowFlags::None);
+  };
+}
+
 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..a1422db70fa37 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,23 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn uadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.uadd_sat(b)
+                       : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn sadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.sadd_sat(b)
+                       : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -205,19 +210,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn usub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.usub_sat(b)
+                       : a.usub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn ssub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.ssub_sat(b)
+                       : a.ssub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
   ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +242,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn umul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.umul_sat(b)
+                       : a.umul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn smul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.smul_sat(b)
+                       : a.smul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -542,32 +557,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
-              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
-              &rhsUMax = rhs.umax();
+  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
 
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
+  // 2^rhs.
+  ConstArithFn ushl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? l.ushl_sat(r)
+                       : l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn sshl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? l.sshl_sat(r)
+                       : l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
   };
-
-  // The minMax inference does not work when there is danger of overflow. In the
-  // signed case, this leads to the obvious problem that the sign bit might
-  // change. In the unsigned case, it also leads to problems because the largest
-  // LHS shifted by the largest RHS does not necessarily result in the largest
-  // result anymore.
-  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
-  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
-      rhsUMax.uge(lhsSMax.getNumSignBits()))
-    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 16524b3634723..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -758,7 +758,7 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
 }
 
 // CHECK-LABEL: func @test_i8_bounds
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
 func.func @test_i8_bounds() -> i8 {
   %cst1 = arith.constant 1 : i8
   %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // nsw flag stops smax from overflowing
+  %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap1() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap2() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // nuw flag stops umin from underflowing
+  %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.muli %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, uma...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 18, 2024

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

Changes

This patch includes the "no signed wrap" and "no unsigned wrap" flags,
which can be used to annotate some Ops in the arith dialect and also
in LLVMIR, in the integer range inference.

The general approach is to use saturating arithmetic operations to infer
bounds which are assumed to not wrap and use overflowing arithmetic
operations in the normal case. If overflow is detected in the normal case,
special handling makes sure that we don't underestimate the result range.

Stacked on #92641


Patch is 24.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92642.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+21-4)
  • (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+18-4)
  • (modified) mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp (+18-4)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+58-40)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+134-1)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+2-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+14-5)
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
+enum class OverflowFlags : uint32_t {
+  None = 0,
+  Nsw = 1,
+  Nuw = 2,
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+
+/// Function that performs inference on an array of `ConstantIntRanges` while
+/// taking special overflow behavior into account.
+using InferRangeWithOvfFlagsFn =
+    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
 ConstantIntRanges truncRange(const ConstantIntRanges &range,
                              unsigned destWidth);
 
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
 
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
 
 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
 
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..6024f37359ad5 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
+intrange::OverflowFlags
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+  intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+    retFlags |= intrange::OverflowFlags::Nsw;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+    retFlags |= intrange::OverflowFlags::Nuw;
+  return retFlags;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..ffa74f0c0faae 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // we take the 64-bit result).
 //===----------------------------------------------------------------------===//
 
+// Some arithmetic inference functions allow specifying special overflow / wrap
+// behavior. We do not require this for the IndexOps and use this helper to call
+// the inference function without any `OverflowFlags`.
+std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+    return inferWithOvfFn(argRanges, OverflowFlags::None);
+  };
+}
+
 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..a1422db70fa37 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,23 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn uadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.uadd_sat(b)
+                       : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn sadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.sadd_sat(b)
+                       : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -205,19 +210,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn usub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.usub_sat(b)
+                       : a.usub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn ssub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.ssub_sat(b)
+                       : a.ssub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
   ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +242,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn umul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.umul_sat(b)
+                       : a.umul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn smul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.smul_sat(b)
+                       : a.smul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -542,32 +557,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
-              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
-              &rhsUMax = rhs.umax();
+  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
 
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
+  // 2^rhs.
+  ConstArithFn ushl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? l.ushl_sat(r)
+                       : l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn sshl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? l.sshl_sat(r)
+                       : l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
   };
-
-  // The minMax inference does not work when there is danger of overflow. In the
-  // signed case, this leads to the obvious problem that the sign bit might
-  // change. In the unsigned case, it also leads to problems because the largest
-  // LHS shifted by the largest RHS does not necessarily result in the largest
-  // result anymore.
-  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
-  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
-      rhsUMax.uge(lhsSMax.getNumSignBits()))
-    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 16524b3634723..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -758,7 +758,7 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
 }
 
 // CHECK-LABEL: func @test_i8_bounds
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
 func.func @test_i8_bounds() -> i8 {
   %cst1 = arith.constant 1 : i8
   %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // nsw flag stops smax from overflowing
+  %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap1() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap2() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // nuw flag stops umin from underflowing
+  %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.muli %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, uma...
[truncated]

@Hardcode84
Copy link
Contributor

Nice, I had this in plans (#77392) for a long time, but never had time to work on it. Thanks for tackling this.

Copy link

github-actions bot commented May 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@ubfx ubfx force-pushed the users/ubfx/intrange-infer-nsw-nuw branch from 5f2b4ea to ce4cbe3 Compare May 19, 2024 09:39
@ubfx
Copy link
Member Author

ubfx commented May 19, 2024

Nice, I had this in plans (#77392) for a long time, but never had time to work on it. Thanks for tackling this.

Thanks for the review, I'll post in the meta issue once this is merged

@ubfx ubfx force-pushed the users/ubfx/intrange-infer-nsw-nuw branch from f1e29ea to 3092534 Compare May 19, 2024 18:28
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a solid change, approved

ubfx added 4 commits May 21, 2024 19:05
This patch includes the "no signed wrap" and "no unsigned wrap" flags,
which can be used to annotate some Ops in the `arith` dialect and also
in LLVMIR, in the integer range inference.

The general approach is to use saturating arithmetic operations to infer
bounds which are assumed to not wrap and use overflowing arithmetic
operations in the normal case. If overflow is detected in the normal case,
special handling makes sure that we don't underestimate the result range.
@ubfx ubfx force-pushed the users/ubfx/intrange-infer-nsw-nuw branch from 3092534 to 9b400b7 Compare May 21, 2024 17:06
@ubfx ubfx merged commit 2034f2f into main May 22, 2024
4 checks passed
@ubfx ubfx deleted the users/ubfx/intrange-infer-nsw-nuw branch May 22, 2024 07:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants