@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
343
343
}
344
344
345
345
KnownBits KnownBits::lshr (const KnownBits &LHS, const KnownBits &RHS,
346
- bool ShAmtNonZero, bool /* Exact*/ ) {
346
+ bool ShAmtNonZero, bool Exact) {
347
347
unsigned BitWidth = LHS.getBitWidth ();
348
348
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
349
349
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
367
367
// Find the common bits from all possible shifts.
368
368
APInt MaxValue = RHS.getMaxValue ();
369
369
unsigned MaxShiftAmount = getMaxShiftAmount (MaxValue, BitWidth);
370
+
371
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
372
+ if (Exact) {
373
+ unsigned FirstOne = LHS.countMaxTrailingZeros ();
374
+ if (FirstOne < MinShiftAmount) {
375
+ // Always poison. Return zero because we don't like returning conflict.
376
+ Known.setAllZero ();
377
+ return Known;
378
+ }
379
+ MaxShiftAmount = std::min (MaxShiftAmount, FirstOne);
380
+ }
381
+
370
382
unsigned ShiftAmtZeroMask = RHS.Zero .zextOrTrunc (32 ).getZExtValue ();
371
383
unsigned ShiftAmtOneMask = RHS.One .zextOrTrunc (32 ).getZExtValue ();
372
384
Known.Zero .setAllBits ();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
389
401
}
390
402
391
403
KnownBits KnownBits::ashr (const KnownBits &LHS, const KnownBits &RHS,
392
- bool ShAmtNonZero, bool /* Exact*/ ) {
404
+ bool ShAmtNonZero, bool Exact) {
393
405
unsigned BitWidth = LHS.getBitWidth ();
394
406
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
395
407
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
415
427
// Find the common bits from all possible shifts.
416
428
APInt MaxValue = RHS.getMaxValue ();
417
429
unsigned MaxShiftAmount = getMaxShiftAmount (MaxValue, BitWidth);
430
+
431
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
432
+ if (Exact) {
433
+ unsigned FirstOne = LHS.countMaxTrailingZeros ();
434
+ if (FirstOne < MinShiftAmount) {
435
+ // Always poison. Return zero because we don't like returning conflict.
436
+ Known.setAllZero ();
437
+ return Known;
438
+ }
439
+ MaxShiftAmount = std::min (MaxShiftAmount, FirstOne);
440
+ }
441
+
418
442
unsigned ShiftAmtZeroMask = RHS.Zero .zextOrTrunc (32 ).getZExtValue ();
419
443
unsigned ShiftAmtOneMask = RHS.One .zextOrTrunc (32 ).getZExtValue ();
420
444
Known.Zero .setAllBits ();
0 commit comments