Skip to content

[mlir][intrange] Use nsw,nuw flags in inference #92642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include <optional>

namespace mlir {
Expand All @@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;

enum class CmpMode : uint32_t { Both, Signed, Unsigned };

enum class OverflowFlags : uint32_t {
None = 0,
Nsw = 1,
Nuw = 2,
LLVM_MARK_AS_BITMASK_ENUM(Nuw)
};

/// Function that performs inference on an array of `ConstantIntRanges` while
/// taking special overflow behavior into account.
using InferRangeWithOvfFlagsFn =
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;

/// Compute `inferFn` on `ranges`, whose size should be the index storage
/// bitwidth. Then, compute the function on `argRanges` again after truncating
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
Expand Down Expand Up @@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
ConstantIntRanges truncRange(const ConstantIntRanges &range,
unsigned destWidth);

ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);

ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);

ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);

ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);

Expand Down Expand Up @@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);

ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);

ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);

ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);

Expand Down
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::intrange;

static intrange::OverflowFlags
convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
retFlags |= intrange::OverflowFlags::Nsw;
if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
retFlags |= intrange::OverflowFlags::Nuw;
return retFlags;
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand All @@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferAdd(argRanges));
setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
getOverflowFlags())));
}

//===----------------------------------------------------------------------===//
Expand All @@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferSub(argRanges));
setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
getOverflowFlags())));
}

//===----------------------------------------------------------------------===//
Expand All @@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferMul(argRanges));
setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
getOverflowFlags())));
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferShl(argRanges));
setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
getOverflowFlags())));
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// we take the 64-bit result).
//===----------------------------------------------------------------------===//

// Some arithmetic inference functions allow specifying special overflow / wrap
// behavior. We do not require this for the IndexOps and use this helper to call
// the inference function without any `OverflowFlags`.
static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
return inferWithOvfFn(argRanges, OverflowFlags::None);
};
}

void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
argRanges, CmpMode::Both));
}

void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
argRanges, CmpMode::Both));
}

void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
argRanges, CmpMode::Both));
}

void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
Expand Down Expand Up @@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
argRanges, CmpMode::Both));
}

void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
Expand Down
99 changes: 59 additions & 40 deletions mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,24 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
//===----------------------------------------------------------------------===//

ConstantIntRanges
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
ConstArithFn uadd = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {

std::function uadd = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.uadd_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nuw)
? a.uadd_sat(b)
: a.uadd_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
ConstArithFn sadd = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {
std::function sadd = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.sadd_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nsw)
? a.sadd_sat(b)
: a.sadd_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};

Expand All @@ -205,19 +211,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//

ConstantIntRanges
mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

ConstArithFn usub = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {
std::function usub = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.usub_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nuw)
? a.usub_sat(b)
: a.usub_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
ConstArithFn ssub = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {
std::function ssub = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.ssub_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nsw)
? a.ssub_sat(b)
: a.ssub_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
ConstantIntRanges urange = computeBoundsBy(
Expand All @@ -232,19 +243,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//

ConstantIntRanges
mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

ConstArithFn umul = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {
std::function umul = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.umul_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nuw)
? a.umul_sat(b)
: a.umul_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
ConstArithFn smul = [](const APInt &a,
const APInt &b) -> std::optional<APInt> {
std::function smul = [=](const APInt &a,
const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
APInt result = a.smul_ov(b, overflowed);
APInt result = any(ovfFlags & OverflowFlags::Nsw)
? a.smul_sat(b)
: a.smul_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};

Expand Down Expand Up @@ -542,32 +558,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//

ConstantIntRanges
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
&lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
&rhsUMax = rhs.umax();
const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();

ConstArithFn shl = [](const APInt &l,
const APInt &r) -> std::optional<APInt> {
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
// The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
// 2^rhs.
std::function ushl = [=](const APInt &l,
const APInt &r) -> std::optional<APInt> {
bool overflowed = false;
APInt result = any(ovfFlags & OverflowFlags::Nuw)
? l.ushl_sat(r)
: l.ushl_ov(r, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
std::function sshl = [=](const APInt &l,
const APInt &r) -> std::optional<APInt> {
bool overflowed = false;
APInt result = any(ovfFlags & OverflowFlags::Nsw)
? l.sshl_sat(r)
: l.sshl_ov(r, overflowed);
return overflowed ? std::optional<APInt>() : result;
};

// The minMax inference does not work when there is danger of overflow. In the
// signed case, this leads to the obvious problem that the sign bit might
// change. In the unsigned case, it also leads to problems because the largest
// LHS shifted by the largest RHS does not necessarily result in the largest
// result anymore.
assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
rhsUMax.uge(lhsSMax.getNumSignBits()))
return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());

ConstantIntRanges urange =
minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
/*isSigned=*/false);
ConstantIntRanges srange =
minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
/*isSigned=*/true);
return urange.intersection(srange);
}
Expand Down
Loading
Loading