
Commit 9bf7930

[mlir][Arith] Let integer range narrowing handle negative values (#119642)
Update integer range narrowing to handle negative values. The previous restriction to narrowing only known-non-negative values wasn't needed: both the signed and unsigned ranges represent valid bounds on the values of each variable in the program, except that one might be more accurate than the other. So, if either the signed or unsigned interpretation of the inputs and outputs allows for integer narrowing, the narrowing is permitted. This commit also updates the integer optimization rewrites to preserve the state of constant-like operations and of those that are narrowed, so that rewrites of other operations don't lose that range information.
1 parent ecdf0da commit 9bf7930
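As a quick illustration of why the old non-negativity restriction was overly conservative, here is a standalone C++ analogy (not code from this commit): a truncate-then-extend round trip is only exact if the extension matches whichever interpretation of the range fits the narrow width.

#include <cstdint>
#include <cstdio>

int main() {
  // Mirror of the rewrite's round trip for i64 -> i32 -> i64.
  int64_t v = -5;                                  // fits i32 as a signed value
  int32_t narrow = static_cast<int32_t>(v);        // like arith.trunci
  int64_t viaSext = narrow;                        // like arith.extsi
  int64_t viaZext = static_cast<uint32_t>(narrow); // like arith.extui
  printf("sext: %lld, zext: %lld\n", (long long)viaSext, (long long)viaZext);
  // Prints "sext: -5, zext: 4294967291": only extsi round-trips a negative
  // value, so the patterns now pick the extension based on which range
  // interpretation is narrow enough, instead of refusing negative ranges.
  return 0;
}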

File tree

2 files changed: +246 −142 lines


mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 141 additions & 121 deletions
@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<IntegerValueRangeLattice>(value);
+        solver.lookupState<IntegerValueRangeLattice>(val);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
 
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
   }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,56 +235,79 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
   assert(srcElemType.isIntOrIndex() && "Invalid src type");
   assert(dstElemType.isIntOrIndex() && "Invalid dst type");
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
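To see what the new check accepts, here is a self-contained sketch of the same arithmetic on a sample range, using llvm::APInt directly; the main() harness is illustrative and not part of the commit.

#include "llvm/ADT/APInt.h"
#include <cstdio>

using llvm::APInt;

int main() {
  // A signed range of [-100, 100] on i64. Narrowing to 32 bits removes
  // 32 bits, so both signed bounds need 32 + 1 = 33 sign bits.
  APInt smin(64, -100, /*isSigned=*/true);
  APInt smax(64, 100, /*isSigned=*/true);
  bool canTruncateSigned =
      smin.getNumSignBits() >= 33 && smax.getNumSignBits() >= 33;

  // The corresponding unsigned range wraps through -1, giving [0, 2^64 - 1],
  // whose upper bound has no leading zeros at all.
  APInt umin = APInt::getZero(64);
  APInt umax = APInt::getAllOnes(64);
  bool canTruncateUnsigned =
      umin.countLeadingZeros() >= 32 && umax.countLeadingZeros() >= 32;

  // Prints "signed: 1, unsigned: 0": this range narrows as CastKind::Signed,
  // so the rebuilt op uses trunci on the way in and extsi on the way out.
  printf("signed: %d, unsigned: %d\n", canTruncateSigned, canTruncateUnsigned);
  return 0;
}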
@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
+                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
        continue;
-
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;
 
      Location loc = op->getLoc();
      IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, preserve non-negativity in the
+        // index_casts since we can't recover this in unsigned when equivalent.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
        }
      });
      SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
      rewriter.replaceOp(op, newResults);
      return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
      return failure();
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
 
+    Type srcType = lhs.getType();
    for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
        continue;
 
      Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
      Location loc = op->getLoc();
      IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
      Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
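The merge used above is a simple meet over cast kinds: None is absorbing, Both is the identity, and Signed/Unsigned survive only when the two sides agree. A standalone sketch (the enum and function body copied from this patch, the harness hypothetical):

#include <cstdint>
#include <cstdio>

enum class CastKind : uint8_t { None, Signed, Unsigned, Both };

// Copied from the patch: Both acts as the identity, None as the absorbing
// element, and mixed Signed/Unsigned collapses to None.
static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
  if (lhs == CastKind::None || rhs == CastKind::None)
    return CastKind::None;
  if (lhs == CastKind::Both)
    return rhs;
  if (rhs == CastKind::Both)
    return lhs;
  if (lhs == rhs)
    return lhs;
  return CastKind::None;
}

int main() {
  // Signed merged with Both keeps Signed (prints 1); Signed merged with
  // Unsigned is contradictory, so no narrowing happens (prints 0).
  printf("%d %d\n", (int)mergeCastKinds(CastKind::Signed, CastKind::Both),
         (int)mergeCastKinds(CastKind::Signed, CastKind::Unsigned));
  return 0;
}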
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
    if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
 
    Value src = srcOp.getIn();
    if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
    auto intType = dyn_cast<IntegerType>(op.getType());
    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
    ArrayRef<unsigned> bitwidthsSupported) {
  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                              bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
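For context, a hedged sketch of how a caller might drive these patterns end to end. The function name narrowTo32Bits and the analysis-loading recipe are assumptions based on the usual MLIR dataflow setup, not something this diff shows; the copyIntegerRange calls added above are what keep the solver's cached ranges valid while the rewrite driver replaces ops.

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: run integer range analysis over `root`, then apply
// the narrowing patterns, which consult the solver's results.
static LogicalResult narrowTo32Bits(Operation *root) {
  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::IntegerRangeAnalysis>();
  if (failed(solver.initializeAndRun(root)))
    return failure();

  RewritePatternSet patterns(root->getContext());
  unsigned bitwidths[] = {32}; // narrow eligible ops to i32 where ranges allow
  arith::populateIntRangeNarrowingPatterns(patterns, solver, bitwidths);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}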
