@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<IntegerValueRangeLattice>(value);
+        solver.lookupState<IntegerValueRangeLattice>(val);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
 
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
   }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,56 +235,79 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
  assert(srcElemType.isIntOrIndex() && "Invalid src type");
  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
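As a sanity check on the sign-bit arithmetic above, the following standalone snippet is an illustrative sketch, not part of this patch: it re-implements the truncatability test locally against raw APInt bounds (the helper name canTruncateTo and the local CastKind mirror are made up for the example) and evaluates it for two 64-bit ranges narrowed to 32 bits.

#include <cstdint>
#include <cstdio>

#include "llvm/ADT/APInt.h"

using llvm::APInt;

// Local mirror of the CastKind enum and the truncatability test, kept
// self-contained so the bit counts can be checked in isolation.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };

static CastKind canTruncateTo(const APInt &smin, const APInt &smax,
                              const APInt &umin, const APInt &umax,
                              unsigned targetWidth) {
  unsigned srcWidth = smin.getBitWidth();
  if (srcWidth <= targetWidth)
    return CastKind::None;
  unsigned removedWidth = srcWidth - targetWidth;
  // Signed: every removed bit, plus the new sign bit, must be a sign copy.
  bool signedOk = smin.getNumSignBits() >= removedWidth + 1 &&
                  smax.getNumSignBits() >= removedWidth + 1;
  // Unsigned: every removed bit must be zero.
  bool unsignedOk = umin.countLeadingZeros() >= removedWidth &&
                    umax.countLeadingZeros() >= removedWidth;
  if (signedOk && unsignedOk)
    return CastKind::Both;
  if (signedOk)
    return CastKind::Signed;
  if (unsignedOk)
    return CastKind::Unsigned;
  return CastKind::None;
}

int main() {
  APInt zero(64, 0);

  // [0, 100]: 57 sign bits and 57 leading zeros on the upper bound, well over
  // the required 33 / 32, so either scheme works -> prints 3 (Both).
  APInt hundred(64, 100);
  std::printf("[0, 100] -> %d\n",
              static_cast<int>(canTruncateTo(zero, hundred, zero, hundred, 32)));

  // [0, 2^32 - 1]: only 32 leading zeros, i.e. 32 sign bits, one short of the
  // 33 needed for a signed narrowing to i32 but exactly enough for an unsigned
  // one -> prints 2 (Unsigned).
  APInt u32max(64, 0xFFFFFFFFULL);
  std::printf("[0, 2^32 - 1] -> %d\n",
              static_cast<int>(canTruncateTo(zero, u32max, zero, u32max, 32)));
  return 0;
}

The second range is why both forms are tracked separately: it round-trips losslessly through trunci/extui but not through trunci/extsi.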
@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
+                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
         continue;
-
       Type targetType = getTargetType(srcType, targetBitwidth);
       if (targetType == srcType)
         continue;
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, preserve non-negativity in the
+        // index_casts since we can't recover this in unsigned when equivalent.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
         }
       });
       SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
       rewriter.replaceOp(op, newResults);
       return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
       return failure();
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
 
+    Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
       Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
     auto intType = dyn_cast<IntegerType>(op.getType());
     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
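For context, a driver for these patterns typically looks roughly like the sketch below. This is an assumption about the surrounding plumbing rather than code from this change: the helper name narrowWithInferredRanges is invented, the header locations and analysis choices follow the usual MLIR setup, and the real pass may manage solver invalidation differently (for example via a rewrite listener).

// Sketch only: the analysis loading and greedy-driver call are assumptions,
// not part of this patch.
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult narrowWithInferredRanges(Operation *root,
                                              ArrayRef<unsigned> bitwidths) {
  // Run the integer range analysis first; DeadCodeAnalysis is loaded alongside
  // it so unreachable code does not pollute the inferred ranges.
  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::IntegerRangeAnalysis>();
  if (failed(solver.initializeAndRun(root)))
    return failure();

  // Populate the narrowing patterns from this file and apply them greedily.
  RewritePatternSet patterns(root->getContext());
  arith::populateIntRangeNarrowingPatterns(patterns, solver, bitwidths);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Running the analysis up front is what gives collectRanges something to look up; without initializeAndRun, every lattice state is uninitialized and all of the patterns simply bail out.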