Skip to content

Commit 9d73a8b

Browse files
committed
[KnownBits] Make shl/lshr/ashr implementations optimal
The implementations for shifts were suboptimal in the case where the max shift amount was >= bitwidth. In that case we should still use the usual code clamped to BitWidth-1 rather than just giving up entirely. Additionally, there was an implementation bug where the known zero bits for the individual shift amounts were not set in the shl/lshr implementations. I think after these changes, we'll be able to drop some of the code in ValueTracking which *also* evaluates all possible shift amounts and has been papering over this issue. For the "all poison" case I've opted to return an unknown value for now. It would be better to return zero, but this has fairly substantial test fallout, so I figured it's best to not mix it into this change. (The "correct" return value would be a conflict, but given that a lot of our APIs assert conflict-freedom, that's probably not the best idea to actually return.) Differential Revision: https://reviews.llvm.org/D150587
1 parent d187cee commit 9d73a8b

File tree

4 files changed

+34
-36
lines changed

4 files changed

+34
-36
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -182,31 +182,34 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
182182
// No matter the shift amount, the trailing zeros will stay zero.
183183
unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
184184

185-
// Minimum shift amount low bits are known zero.
186185
APInt MinShiftAmount = RHS.getMinValue();
187-
if (MinShiftAmount.ult(BitWidth)) {
188-
MinTrailingZeros += MinShiftAmount.getZExtValue();
189-
MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
190-
}
186+
if (MinShiftAmount.uge(BitWidth))
187+
// Always poison. Return unknown because we don't like returning conflict.
188+
return Known;
189+
190+
// Minimum shift amount low bits are known zero.
191+
MinTrailingZeros += MinShiftAmount.getZExtValue();
192+
MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
191193

192194
// If the maximum shift is in range, then find the common bits from all
193195
// possible shifts.
194196
APInt MaxShiftAmount = RHS.getMaxValue();
195-
if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
197+
if (!LHS.isUnknown()) {
196198
uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
197199
uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
198200
assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
199201
Known.Zero.setAllBits();
200202
Known.One.setAllBits();
201203
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
202-
MaxShiftAmt = MaxShiftAmount.getZExtValue();
204+
MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
203205
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
204206
// Skip if the shift amount is impossible.
205207
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
206208
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
207209
continue;
208210
KnownBits SpecificShift;
209211
SpecificShift.Zero = LHS.Zero << ShiftAmt;
212+
SpecificShift.Zero.setLowBits(ShiftAmt);
210213
SpecificShift.One = LHS.One << ShiftAmt;
211214
Known = KnownBits::commonBits(Known, SpecificShift);
212215
if (Known.isUnknown())
@@ -237,29 +240,32 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
237240

238241
// Minimum shift amount high bits are known zero.
239242
APInt MinShiftAmount = RHS.getMinValue();
240-
if (MinShiftAmount.ult(BitWidth)) {
241-
MinLeadingZeros += MinShiftAmount.getZExtValue();
242-
MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
243-
}
243+
if (MinShiftAmount.uge(BitWidth))
244+
// Always poison. Return unknown because we don't like returning conflict.
245+
return Known;
246+
247+
MinLeadingZeros += MinShiftAmount.getZExtValue();
248+
MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
244249

245250
// If the maximum shift is in range, then find the common bits from all
246251
// possible shifts.
247252
APInt MaxShiftAmount = RHS.getMaxValue();
248-
if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
253+
if (!LHS.isUnknown()) {
249254
uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
250255
uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
251256
assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
252257
Known.Zero.setAllBits();
253258
Known.One.setAllBits();
254259
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
255-
MaxShiftAmt = MaxShiftAmount.getZExtValue();
260+
MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
256261
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
257262
// Skip if the shift amount is impossible.
258263
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
259264
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
260265
continue;
261266
KnownBits SpecificShift = LHS;
262267
SpecificShift.Zero.lshrInPlace(ShiftAmt);
268+
SpecificShift.Zero.setHighBits(ShiftAmt);
263269
SpecificShift.One.lshrInPlace(ShiftAmt);
264270
Known = KnownBits::commonBits(Known, SpecificShift);
265271
if (Known.isUnknown())
@@ -289,28 +295,30 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
289295

290296
// Minimum shift amount high bits are known sign bits.
291297
APInt MinShiftAmount = RHS.getMinValue();
292-
if (MinShiftAmount.ult(BitWidth)) {
293-
if (MinLeadingZeros) {
294-
MinLeadingZeros += MinShiftAmount.getZExtValue();
295-
MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
296-
}
297-
if (MinLeadingOnes) {
298-
MinLeadingOnes += MinShiftAmount.getZExtValue();
299-
MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
300-
}
298+
if (MinShiftAmount.uge(BitWidth))
299+
// Always poison. Return unknown because we don't like returning conflict.
300+
return Known;
301+
302+
if (MinLeadingZeros) {
303+
MinLeadingZeros += MinShiftAmount.getZExtValue();
304+
MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
305+
}
306+
if (MinLeadingOnes) {
307+
MinLeadingOnes += MinShiftAmount.getZExtValue();
308+
MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
301309
}
302310

303311
// If the maximum shift is in range, then find the common bits from all
304312
// possible shifts.
305313
APInt MaxShiftAmount = RHS.getMaxValue();
306-
if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
314+
if (!LHS.isUnknown()) {
307315
uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
308316
uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
309317
assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
310318
Known.Zero.setAllBits();
311319
Known.One.setAllBits();
312320
for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
313-
MaxShiftAmt = MaxShiftAmount.getZExtValue();
321+
MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
314322
ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
315323
// Skip if the shift amount is impossible.
316324
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||

llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ for.end:
221221
; SI-PROMOTE-VECT: s_load_dword [[IDX:s[0-9]+]]
222222
; SI-PROMOTE-VECT: s_lshl_b32 [[SCALED_IDX:s[0-9]+]], [[IDX]], 4
223223
; SI-PROMOTE-VECT: s_lshr_b32 [[SREG:s[0-9]+]], 0x10000, [[SCALED_IDX]]
224-
; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 0xffff
224+
; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 1
225225
define amdgpu_kernel void @short_array(ptr addrspace(1) %out, i32 %index) #0 {
226226
entry:
227227
%0 = alloca [2 x i16], addrspace(5)

llvm/test/Transforms/InstCombine/not-add.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ define void @pr50370(i32 %x) {
172172
; CHECK-NEXT: entry:
173173
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], 1
174174
; CHECK-NEXT: [[B15:%.*]] = srem i32 ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)), [[XOR]]
175-
; CHECK-NEXT: [[B12:%.*]] = add nuw nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
175+
; CHECK-NEXT: [[B12:%.*]] = add nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
176176
; CHECK-NEXT: [[B:%.*]] = xor i32 [[B12]], -1
177177
; CHECK-NEXT: store i32 [[B]], ptr undef, align 4
178178
; CHECK-NEXT: ret void

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
270270
},
271271
checkCorrectnessOnlyBinary);
272272

273-
// TODO: Make optimal for non-constant cases.
274273
testBinaryOpExhaustive(
275274
[](const KnownBits &Known1, const KnownBits &Known2) {
276275
return KnownBits::shl(Known1, Known2);
@@ -279,9 +278,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
279278
if (N2.uge(N2.getBitWidth()))
280279
return std::nullopt;
281280
return N1.shl(N2);
282-
},
283-
[](const KnownBits &, const KnownBits &Known) {
284-
return Known.isConstant();
285281
});
286282
testBinaryOpExhaustive(
287283
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -291,9 +287,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
291287
if (N2.uge(N2.getBitWidth()))
292288
return std::nullopt;
293289
return N1.lshr(N2);
294-
},
295-
[](const KnownBits &, const KnownBits &Known) {
296-
return Known.isConstant();
297290
});
298291
testBinaryOpExhaustive(
299292
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -303,9 +296,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
303296
if (N2.uge(N2.getBitWidth()))
304297
return std::nullopt;
305298
return N1.ashr(N2);
306-
},
307-
[](const KnownBits &, const KnownBits &Known) {
308-
return Known.isConstant();
309299
});
310300

311301
testBinaryOpExhaustive(

0 commit comments

Comments
 (0)