[mlir][Arith] Let integer range narrowing handle negative values #119642

Merged: 2 commits, Dec 13, 2024
262 changes: 141 additions & 121 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}

static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
assert(oldVal.getType() == newVal.getType() &&
"Can't copy integer ranges between different types");
auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
if (!oldState)
return;
(void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
*oldState);
}

/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!constOp)
return failure();

copyIntegerRange(solver, value, constOp->getResult(0));
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
return success();
}
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};

/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
Type elemType = getElementTypeOrSelf(type);
if (isa<IndexType>(elemType))
return success();

if (auto intType = dyn_cast<IntegerType>(elemType))
if (intType.getWidth() > targetBitwidth)
return success();

return failure();
}

/// Check if op have same type for all operands and results and this type
/// is suitable for truncation.
static LogicalResult checkElementwiseOpType(Operation *op,
unsigned targetBitwidth) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();

Type type;
for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
if (!type) {
type = val.getType();
continue;
}

if (type != val.getType())
return failure();
}

return checkIntType(type, targetBitwidth);
}

/// Return union of all operands values ranges.
static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
ValueRange operands) {
std::optional<ConstantIntRanges> ret;
for (Value value : operands) {
/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
SmallVectorImpl<ConstantIntRanges> &ranges) {
for (Value val : values) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
solver.lookupState<IntegerValueRangeLattice>(val);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return std::nullopt;
return failure();

const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();

ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
ranges.push_back(inferredRange);
}
return ret;
return success();
}

/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,56 +235,79 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
return dstType;
}

/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
APInt smax, APInt umin, APInt umax) {
auto sge = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.sext(width);
val2 = val2.sext(width);
return val1.sge(val2);
};
auto sle = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.sext(width);
val2 = val2.sext(width);
return val1.sle(val2);
};
auto uge = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.zext(width);
val2 = val2.zext(width);
return val1.uge(val2);
};
auto ule = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.zext(width);
val2 = val2.zext(width);
return val1.ule(val2);
};
return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
uge(range.umin(), umin) && ule(range.umax(), umax));
namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace

/// If the values within `range` can be represented using only `width` bits,
/// return the kind of truncation needed to preserve that property.
///
/// This check relies on the fact that the signed and unsigned ranges are both
/// always correct, but that one might be an approximation of the other,
/// so we want to use the correct truncation operation.
static CastKind checkTruncatability(const ConstantIntRanges &range,
unsigned targetWidth) {
unsigned srcWidth = range.smin().getBitWidth();
if (srcWidth <= targetWidth)
return CastKind::None;
unsigned removedWidth = srcWidth - targetWidth;
// The sign bits need to extend into the sign bit of the target width. For
// example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
// bits.
bool canTruncateSigned =
range.smin().getNumSignBits() >= (removedWidth + 1) &&
range.smax().getNumSignBits() >= (removedWidth + 1);
bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
range.umax().countLeadingZeros() >= removedWidth;
if (canTruncateSigned && canTruncateUnsigned)
return CastKind::Both;
if (canTruncateSigned)
return CastKind::Signed;
if (canTruncateUnsigned)
return CastKind::Unsigned;
return CastKind::None;
}
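
To make the sign-bit rule above concrete, here is a small standalone sketch using llvm::APInt directly (an illustration only, not part of the patch; it assumes narrowing from 64 to 32 bits):

#include <cassert>
#include "llvm/ADT/APInt.h"

using llvm::APInt;

int main() {
  // Bounds of a value known to lie in [-8, 7] as an i64. Both bounds have at
  // least 64 - 32 + 1 = 33 sign bits, so trunci to i32 followed by extsi
  // round-trips the value: a signed truncation is safe.
  APInt smin(64, -8, /*isSigned=*/true);
  APInt smax(64, 7, /*isSigned=*/true);
  assert(smin.getNumSignBits() >= 33 && smax.getNumSignBits() >= 33);

  // Bounds of a value known to lie in [0, 0xFFFFFFFF] unsigned. Both bounds
  // have at least 32 leading zeros, so trunci followed by extui is safe, but
  // 0xFFFFFFFF has only 32 sign bits, so a signed round-trip through i32 would
  // turn it into -1 and extend it back into an all-ones i64.
  APInt umin(64, 0);
  APInt umax(64, 0xFFFFFFFFULL);
  assert(umin.countLeadingZeros() >= 32 && umax.countLeadingZeros() >= 32);
  assert(umax.getNumSignBits() < 33);
  return 0;
}

checkTruncatability above applies exactly these two tests to the smin/smax and umin/umax bounds of each inferred range.
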

static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
if (lhs == CastKind::None || rhs == CastKind::None)
return CastKind::None;
if (lhs == CastKind::Both)
return rhs;
if (rhs == CastKind::Both)
return lhs;
if (lhs == rhs)
return lhs;
return CastKind::None;
}

static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
CastKind castKind) {
Type srcType = src.getType();
assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
"Mixing vector and non-vector types");
assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
Type srcElemType = getElementTypeOrSelf(srcType);
Type dstElemType = getElementTypeOrSelf(dstType);
assert(srcElemType.isIntOrIndex() && "Invalid src type");
assert(dstElemType.isIntOrIndex() && "Invalid dst type");
if (srcType == dstType)
return src;

if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
if (castKind == CastKind::Signed)
return builder.create<arith::IndexCastOp>(loc, dstType, src);
return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
}

auto srcInt = cast<IntegerType>(srcElemType);
auto dstInt = cast<IntegerType>(dstElemType);
if (dstInt.getWidth() < srcInt.getWidth())
return builder.create<arith::TruncIOp>(loc, dstType, src);

if (castKind == CastKind::Signed)
return builder.create<arith::ExtSIOp>(loc, dstType, src);
return builder.create<arith::ExtUIOp>(loc, dstType, src);
}
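
As a reading aid (a summary of the helper above, not additional functionality), the cast operation doCast picks can be tabulated as:

// Cast selection in doCast, assuming the source and destination types differ
// and castKind != CastKind::None:
//   index involved,   CastKind::Signed            -> arith.index_cast
//   index involved,   CastKind::Unsigned or Both  -> arith.index_castui
//   int -> narrower int, any allowed kind         -> arith.trunci
//   int -> wider int, CastKind::Signed            -> arith.extsi
//   int -> wider int, CastKind::Unsigned or Both  -> arith.extui
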

@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
std::optional<ConstantIntRanges> range =
getOperandsRange(solver, op->getResults());
if (!range)
return failure();
if (op->getNumResults() == 0)
return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

SmallVector<ConstantIntRanges> ranges;
if (failed(collectRanges(solver, op->getOperands(), ranges)))
return rewriter.notifyMatchFailure(op, "input without specified range");
if (failed(collectRanges(solver, op->getResults(), ranges)))
return rewriter.notifyMatchFailure(op, "output without specified range");

Type srcType = op->getResult(0).getType();
if (!llvm::all_equal(op->getResultTypes()))
return rewriter.notifyMatchFailure(op, "mismatched result types");
if (op->getNumOperands() == 0 ||
!llvm::all_of(op->getOperandTypes(),
[=](Type t) { return t == srcType; }))
return rewriter.notifyMatchFailure(
op, "no operands or operand types don't match result type");

for (unsigned targetBitwidth : targetBitwidths) {
if (failed(checkElementwiseOpType(op, targetBitwidth)))
continue;

Type srcType = op->getResult(0).getType();

// We are truncating op args to the desired bitwidth before the op and
// then extending op results back to the original width after. extui and
// exti will produce different results for negative values, so limit
// signed range to non-negative values.
auto smin = APInt::getZero(targetBitwidth);
auto smax = APInt::getSignedMaxValue(targetBitwidth);
auto umin = APInt::getMinValue(targetBitwidth);
auto umax = APInt::getMaxValue(targetBitwidth);
if (failed(checkRange(*range, smin, smax, umin, umax)))
CastKind castKind = CastKind::Both;
for (const ConstantIntRanges &range : ranges) {
castKind = mergeCastKinds(castKind,
checkTruncatability(range, targetBitwidth));
if (castKind == CastKind::None)
break;
}
if (castKind == CastKind::None)
continue;

Type targetType = getTargetType(srcType, targetBitwidth);
if (targetType == srcType)
continue;

Location loc = op->getLoc();
IRMapping mapping;
for (Value arg : op->getOperands()) {
Value newArg = doCast(rewriter, loc, arg, targetType);
for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
CastKind argCastKind = castKind;
// When dealing with `index` values, preserve non-negativity in the
// index_casts since we can't recover this in unsigned when equivalent.
if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
argCastKind = CastKind::Both;
Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
mapping.map(arg, newArg);
}

@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
}
});
SmallVector<Value> newResults;
for (Value res : newOp->getResults())
newResults.emplace_back(doCast(rewriter, loc, res, srcType));
for (auto [newRes, oldRes] :
llvm::zip_equal(newOp->getResults(), op->getResults())) {
Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
copyIntegerRange(solver, oldRes, castBack);
newResults.push_back(castBack);
}

rewriter.replaceOp(op, newResults);
return success();
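
The effect of the CastKind plumbing on the emitted IR, sketched in comments with an invented example (an i64 addition whose operands and result are all proven to lie in [-100, 100], narrowed to 32 bits):

// Before:
//   %r = arith.addi %a, %b : i64
//
// After: the signed bounds fit in i32 but the unsigned ones do not, so
// checkTruncatability yields CastKind::Signed and the result is widened back
// with extsi rather than extui:
//   %a32 = arith.trunci %a : i64 to i32
//   %b32 = arith.trunci %b : i64 to i32
//   %r32 = arith.addi %a32, %b32 : i32
//   %r   = arith.extsi %r32 : i32 to i64
//
// The previous checkRange-based logic clamped the allowed signed range to
// non-negative values, so a range like [-100, 100] was never narrowed.
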
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
Value lhs = op.getLhs();
Value rhs = op.getRhs();

std::optional<ConstantIntRanges> range =
getOperandsRange(solver, {lhs, rhs});
if (!range)
SmallVector<ConstantIntRanges> ranges;
if (failed(collectRanges(solver, op.getOperands(), ranges)))
return failure();
const ConstantIntRanges &lhsRange = ranges[0];
const ConstantIntRanges &rhsRange = ranges[1];

Type srcType = lhs.getType();
for (unsigned targetBitwidth : targetBitwidths) {
Type srcType = lhs.getType();
if (failed(checkIntType(srcType, targetBitwidth)))
continue;

auto smin = APInt::getSignedMinValue(targetBitwidth);
auto smax = APInt::getSignedMaxValue(targetBitwidth);
auto umin = APInt::getMinValue(targetBitwidth);
auto umax = APInt::getMaxValue(targetBitwidth);
if (failed(checkRange(*range, smin, smax, umin, umax)))
CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
// Note: this includes target width > src width.
if (castKind == CastKind::None)
continue;

Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {

Location loc = op->getLoc();
IRMapping mapping;
for (Value arg : op->getOperands()) {
Value newArg = doCast(rewriter, loc, arg, targetType);
mapping.map(arg, newArg);
}
Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
mapping.map(lhs, lhsCast);
mapping.map(rhs, rhsCast);

Operation *newOp = rewriter.clone(*op, mapping);
copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
rewriter.replaceOp(op, newOp->getResults());
return success();
}
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes all passed `targetBitwidths` are not wider than index
/// type.
struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
: OpRewritePattern(context), targetBitwidths(target) {}
: OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
LogicalResult matchAndRewrite(CastOp op,
PatternRewriter &rewriter) const override {
auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
auto srcOp = op.getIn().template getDefiningOp<CastOp>();
if (!srcOp)
return failure();
return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

Value src = srcOp.getIn();
if (src.getType() != op.getType())
return failure();
return rewriter.notifyMatchFailure(op, "outer types don't match");

if (!srcOp.getType().isIndex())
return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

auto intType = dyn_cast<IntegerType>(op.getType());
if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
Expand Down Expand Up @@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
ArrayRef<unsigned> bitwidthsSupported) {
patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
bitwidthsSupported);
patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
bitwidthsSupported);
}
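
For context, a minimal sketch of how these patterns are typically driven from a pass. The headers and the exact populate signature beyond what is visible in the hunk above are assumptions based on the usual MLIR dataflow setup:

// Sketch only: mirrors the common way range-based rewrites are driven,
// not code taken from this file.
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult narrowIntegerOps(Operation *root,
                                      ArrayRef<unsigned> bitwidths) {
  // Run the integer range analysis so the patterns can query value ranges.
  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::IntegerRangeAnalysis>();
  if (failed(solver.initializeAndRun(root)))
    return failure();

  // Populate and greedily apply the narrowing patterns.
  RewritePatternSet patterns(root->getContext());
  arith::populateIntRangeNarrowingPatterns(patterns, solver, bitwidths);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
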

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {