Skip to content

[mlir][int-range] Update int range inference for arith.xori #117272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 26, 2024

Conversation

Hardcode84
Copy link
Contributor

Previous impl was getting incorrect results for widths > i1 and was disabled.

While same algorithm can be used for andi and ori too, without additional modifications it will produce less precise result.

@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Previous impl was getting incorrect results for widths > i1 and was disabled.

While same algorithm can be used for andi and ori too, without additional modifications it will produce less precise result.


Full diff: https://github.com/llvm/llvm-project/pull/117272.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+17-21)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+2-2)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2c1276d577a55b..7a73a94201f1d6 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
                   /*isSigned=*/false);
 }
 
+/// Get bitmask of all bits which can change while iterating in
+/// [bound.umin(), bound.umax()].
+static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
+  return APInt::getLowBitsSet(bitwidth, differingBits);
+}
+
 ConstantIntRanges
 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
-  // TODO: The code below doesn't work for bitwidths > i1.
-  // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
-  // widenBitwiseBounds will produce:
-  // lhs:
-  // 2060639848  01111010110100101101111001101000
-  // 2060639851  01111010110100101101111001101011
-  // rhs:
-  // 2060639849  01111010110100101101111001101001
-  // 2060639849  01111010110100101101111001101001
-  // None of those combinations xor to 0, while intermediate values does.
-  unsigned width = argRanges[0].umin().getBitWidth();
-  if (width > 1)
-    return ConstantIntRanges::maxRange(width);
-
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a ^ b;
-  };
-  return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  // Construct mask of varying bits for both ranges, xor values and then replace
+  // masked bits with 0s and 1s to get min and max values respectively.
+  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
+  APInt res = lhs.umin() ^ rhs.umin();
+  APInt min = res & ~mask;
+  APInt max = res | mask;
+  return ConstantIntRanges::fromUnsigned(min, max);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4db846fa4656a3..48a3eb20eb7fb0 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
 }
 
 // CHECK-LABEL: func @xori
-// TODO: xor folding is temporarily disabled
-// CHECK-NOT: arith.constant false
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
 func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
     %c0 = arith.constant 0 : i64
     %c7 = arith.constant 7 : i64

Previous impl was getting incorrect results for widths > i1 and was disabled.

While same algorithm can be used for `andi` and `ori` too, , without additional modifications it will produce less precise result.
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it's fine

@Hardcode84 Hardcode84 merged commit d471c85 into llvm:main Nov 26, 2024
8 checks passed
@Hardcode84 Hardcode84 deleted the int-range-bitwise branch November 26, 2024 10:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants