@@ -182,31 +182,34 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
182
182
// No matter the shift amount, the trailing zeros will stay zero.
183
183
unsigned MinTrailingZeros = LHS.countMinTrailingZeros ();
184
184
185
- // Minimum shift amount low bits are known zero.
186
185
APInt MinShiftAmount = RHS.getMinValue ();
187
- if (MinShiftAmount.ult (BitWidth)) {
188
- MinTrailingZeros += MinShiftAmount.getZExtValue ();
189
- MinTrailingZeros = std::min (MinTrailingZeros, BitWidth);
190
- }
186
+ if (MinShiftAmount.uge (BitWidth))
187
+ // Always poison. Return unknown because we don't like returning conflict.
188
+ return Known;
189
+
190
+ // Minimum shift amount low bits are known zero.
191
+ MinTrailingZeros += MinShiftAmount.getZExtValue ();
192
+ MinTrailingZeros = std::min (MinTrailingZeros, BitWidth);
191
193
192
194
// If the maximum shift is in range, then find the common bits from all
193
195
// possible shifts.
194
196
APInt MaxShiftAmount = RHS.getMaxValue ();
195
- if (MaxShiftAmount. ult (BitWidth) && !LHS.isUnknown ()) {
197
+ if (!LHS.isUnknown ()) {
196
198
uint64_t ShiftAmtZeroMask = (~RHS.Zero ).getZExtValue ();
197
199
uint64_t ShiftAmtOneMask = RHS.One .getZExtValue ();
198
200
assert (MinShiftAmount.ult (MaxShiftAmount) && " Illegal shift range" );
199
201
Known.Zero .setAllBits ();
200
202
Known.One .setAllBits ();
201
203
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue (),
202
- MaxShiftAmt = MaxShiftAmount.getZExtValue ( );
204
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue (BitWidth - 1 );
203
205
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
204
206
// Skip if the shift amount is impossible.
205
207
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
206
208
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
207
209
continue ;
208
210
KnownBits SpecificShift;
209
211
SpecificShift.Zero = LHS.Zero << ShiftAmt;
212
+ SpecificShift.Zero .setLowBits (ShiftAmt);
210
213
SpecificShift.One = LHS.One << ShiftAmt;
211
214
Known = KnownBits::commonBits (Known, SpecificShift);
212
215
if (Known.isUnknown ())
@@ -237,29 +240,32 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
237
240
238
241
// Minimum shift amount high bits are known zero.
239
242
APInt MinShiftAmount = RHS.getMinValue ();
240
- if (MinShiftAmount.ult (BitWidth)) {
241
- MinLeadingZeros += MinShiftAmount.getZExtValue ();
242
- MinLeadingZeros = std::min (MinLeadingZeros, BitWidth);
243
- }
243
+ if (MinShiftAmount.uge (BitWidth))
244
+ // Always poison. Return unknown because we don't like returning conflict.
245
+ return Known;
246
+
247
+ MinLeadingZeros += MinShiftAmount.getZExtValue ();
248
+ MinLeadingZeros = std::min (MinLeadingZeros, BitWidth);
244
249
245
250
// If the maximum shift is in range, then find the common bits from all
246
251
// possible shifts.
247
252
APInt MaxShiftAmount = RHS.getMaxValue ();
248
- if (MaxShiftAmount. ult (BitWidth) && !LHS.isUnknown ()) {
253
+ if (!LHS.isUnknown ()) {
249
254
uint64_t ShiftAmtZeroMask = (~RHS.Zero ).getZExtValue ();
250
255
uint64_t ShiftAmtOneMask = RHS.One .getZExtValue ();
251
256
assert (MinShiftAmount.ult (MaxShiftAmount) && " Illegal shift range" );
252
257
Known.Zero .setAllBits ();
253
258
Known.One .setAllBits ();
254
259
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue (),
255
- MaxShiftAmt = MaxShiftAmount.getZExtValue ( );
260
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue (BitWidth - 1 );
256
261
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
257
262
// Skip if the shift amount is impossible.
258
263
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
259
264
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
260
265
continue ;
261
266
KnownBits SpecificShift = LHS;
262
267
SpecificShift.Zero .lshrInPlace (ShiftAmt);
268
+ SpecificShift.Zero .setHighBits (ShiftAmt);
263
269
SpecificShift.One .lshrInPlace (ShiftAmt);
264
270
Known = KnownBits::commonBits (Known, SpecificShift);
265
271
if (Known.isUnknown ())
@@ -289,28 +295,30 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
289
295
290
296
// Minimum shift amount high bits are known sign bits.
291
297
APInt MinShiftAmount = RHS.getMinValue ();
292
- if (MinShiftAmount.ult (BitWidth)) {
293
- if (MinLeadingZeros) {
294
- MinLeadingZeros += MinShiftAmount.getZExtValue ();
295
- MinLeadingZeros = std::min (MinLeadingZeros, BitWidth);
296
- }
297
- if (MinLeadingOnes) {
298
- MinLeadingOnes += MinShiftAmount.getZExtValue ();
299
- MinLeadingOnes = std::min (MinLeadingOnes, BitWidth);
300
- }
298
+ if (MinShiftAmount.uge (BitWidth))
299
+ // Always poison. Return unknown because we don't like returning conflict.
300
+ return Known;
301
+
302
+ if (MinLeadingZeros) {
303
+ MinLeadingZeros += MinShiftAmount.getZExtValue ();
304
+ MinLeadingZeros = std::min (MinLeadingZeros, BitWidth);
305
+ }
306
+ if (MinLeadingOnes) {
307
+ MinLeadingOnes += MinShiftAmount.getZExtValue ();
308
+ MinLeadingOnes = std::min (MinLeadingOnes, BitWidth);
301
309
}
302
310
303
311
// If the maximum shift is in range, then find the common bits from all
304
312
// possible shifts.
305
313
APInt MaxShiftAmount = RHS.getMaxValue ();
306
- if (MaxShiftAmount. ult (BitWidth) && !LHS.isUnknown ()) {
314
+ if (!LHS.isUnknown ()) {
307
315
uint64_t ShiftAmtZeroMask = (~RHS.Zero ).getZExtValue ();
308
316
uint64_t ShiftAmtOneMask = RHS.One .getZExtValue ();
309
317
assert (MinShiftAmount.ult (MaxShiftAmount) && " Illegal shift range" );
310
318
Known.Zero .setAllBits ();
311
319
Known.One .setAllBits ();
312
320
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue (),
313
- MaxShiftAmt = MaxShiftAmount.getZExtValue ( );
321
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue (BitWidth - 1 );
314
322
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
315
323
// Skip if the shift amount is impossible.
316
324
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
0 commit comments