Skip to content

Commit d471c85

Browse files
authored
[mlir][int-range] Update int range inference for arith.xori (#117272)
Previous impl was getting incorrect results for widths > i1 and was disabled. While same algorithm can be used for `andi` and `ori` too, without additional modifications it will produce less precise result.
1 parent b9e3a76 commit d471c85

File tree

2 files changed

+19
-23
lines changed

2 files changed

+19
-23
lines changed

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
556556
/*isSigned=*/false);
557557
}
558558

559+
/// Get bitmask of all bits which can change while iterating in
560+
/// [bound.umin(), bound.umax()].
561+
static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
562+
APInt leftVal = bound.umin(), rightVal = bound.umax();
563+
unsigned bitwidth = leftVal.getBitWidth();
564+
unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
565+
return APInt::getLowBitsSet(bitwidth, differingBits);
566+
}
567+
559568
ConstantIntRanges
560569
mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
561-
// TODO: The code below doesn't work for bitwidths > i1.
562-
// For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
563-
// widenBitwiseBounds will produce:
564-
// lhs:
565-
// 2060639848 01111010110100101101111001101000
566-
// 2060639851 01111010110100101101111001101011
567-
// rhs:
568-
// 2060639849 01111010110100101101111001101001
569-
// 2060639849 01111010110100101101111001101001
570-
// None of those combinations xor to 0, while intermediate values does.
571-
unsigned width = argRanges[0].umin().getBitWidth();
572-
if (width > 1)
573-
return ConstantIntRanges::maxRange(width);
574-
575-
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
576-
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
577-
auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
578-
return a ^ b;
579-
};
580-
return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
581-
/*isSigned=*/false);
570+
// Construct mask of varying bits for both ranges, xor values and then replace
571+
// masked bits with 0s and 1s to get min and max values respectively.
572+
ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
573+
APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
574+
APInt res = lhs.umin() ^ rhs.umin();
575+
APInt min = res & ~mask;
576+
APInt max = res | mask;
577+
return ConstantIntRanges::fromUnsigned(min, max);
582578
}
583579

584580
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/int-range-interface.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
481481
}
482482

483483
// CHECK-LABEL: func @xori
484-
// TODO: xor folding is temporarily disabled
485-
// CHECK-NOT: arith.constant false
484+
// CHECK: %[[false:.*]] = arith.constant false
485+
// CHECK: return %[[false]]
486486
func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
487487
%c0 = arith.constant 0 : i64
488488
%c7 = arith.constant 7 : i64

0 commit comments

Comments
 (0)