Skip to content

Commit d067014

Browse files
liamsemeria and RKSimon
authored
[APInt] Added APInt::clearBits() method (#137098)
Added APInt::clearBits(unsigned LoBit, unsigned HiBit), which clears the bits in the half-open range [LoBit, HiBit), i.e. LoBit inclusive to HiBit exclusive. Fixes #136550 --------- Co-authored-by: Simon Pilgrim <[email protected]>
1 parent 572add0 commit d067014

File tree

5 files changed

+117
-4
lines changed

5 files changed

+117
-4
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,6 +1412,25 @@ class [[nodiscard]] APInt {
14121412
U.pVal[whichWord(BitPosition)] &= Mask;
14131413
}
14141414

1415+
/// Clear the bits from LoBit (inclusive) to HiBit (exclusive) to 0.
1416+
/// This function handles case when \p LoBit <= \p HiBit.
1417+
void clearBits(unsigned LoBit, unsigned HiBit) {
1418+
assert(HiBit <= BitWidth && "HiBit out of range");
1419+
assert(LoBit <= HiBit && "LoBit greater than HiBit");
1420+
if (LoBit == HiBit)
1421+
return;
1422+
if (HiBit <= APINT_BITS_PER_WORD) {
1423+
uint64_t Mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - (HiBit - LoBit));
1424+
Mask = ~(Mask << LoBit);
1425+
if (isSingleWord())
1426+
U.VAL &= Mask;
1427+
else
1428+
U.pVal[0] &= Mask;
1429+
} else {
1430+
clearBitsSlowCase(LoBit, HiBit);
1431+
}
1432+
}
1433+
14151434
/// Set bottom loBits bits to 0.
14161435
void clearLowBits(unsigned loBits) {
14171436
assert(loBits <= BitWidth && "More bits than bitwidth");
@@ -2052,6 +2071,9 @@ class [[nodiscard]] APInt {
20522071
/// out-of-line slow case for setBits.
20532072
void setBitsSlowCase(unsigned loBit, unsigned hiBit);
20542073

2074+
/// out-of-line slow case for clearBits.
2075+
void clearBitsSlowCase(unsigned LoBit, unsigned HiBit);
2076+
20552077
/// out-of-line slow case for flipAllBits.
20562078
void flipAllBitsSlowCase();
20572079

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,7 +3546,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35463546
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
35473547
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
35483548
APInt DemandedSrcElts = DemandedElts;
3549-
DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
3549+
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
35503550

35513551
Known.One.setAllBits();
35523552
Known.Zero.setAllBits();
@@ -5230,7 +5230,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
52305230
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
52315231
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
52325232
APInt DemandedSrcElts = DemandedElts;
5233-
DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
5233+
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
52345234

52355235
Tmp = std::numeric_limits<unsigned>::max();
52365236
if (!!DemandedSubElts) {

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,7 @@ bool TargetLowering::SimplifyDemandedBits(
12901290
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
12911291
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
12921292
APInt DemandedSrcElts = DemandedElts;
1293-
DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
1293+
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
12941294

12951295
KnownBits KnownSub, KnownSrc;
12961296
if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
@@ -3357,7 +3357,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
33573357
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
33583358
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
33593359
APInt DemandedSrcElts = DemandedElts;
3360-
DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
3360+
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
33613361

33623362
APInt SubUndef, SubZero;
33633363
if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,

llvm/lib/Support/APInt.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,33 @@ void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) {
336336
U.pVal[word] = WORDTYPE_MAX;
337337
}
338338

339+
void APInt::clearBitsSlowCase(unsigned LoBit, unsigned HiBit) {
340+
unsigned LoWord = whichWord(LoBit);
341+
unsigned HiWord = whichWord(HiBit);
342+
343+
// Create an initial mask for the low word with ones below loBit.
344+
uint64_t LoMask = ~(WORDTYPE_MAX << whichBit(LoBit));
345+
346+
// If HiBit is not aligned, we need a high mask.
347+
unsigned HiShiftAmt = whichBit(HiBit);
348+
if (HiShiftAmt != 0) {
349+
// Create a high mask with ones above HiBit.
350+
uint64_t HiMask = ~(WORDTYPE_MAX >> (APINT_BITS_PER_WORD - HiShiftAmt));
351+
// If LoWord and HiWord are equal, then we combine the masks. Otherwise,
352+
// set the bits in HiWord.
353+
if (HiWord == LoWord)
354+
LoMask &= HiMask;
355+
else
356+
U.pVal[HiWord] &= HiMask;
357+
}
358+
// Apply the mask to the low word.
359+
U.pVal[LoWord] &= LoMask;
360+
361+
// Fill any words between LoWord and HiWord with all zeros.
362+
for (unsigned Word = LoWord + 1; Word < HiWord; ++Word)
363+
U.pVal[Word] = 0;
364+
}
365+
339366
// Complement a bignum in-place.
340367
static void tcComplement(APInt::WordType *dst, unsigned parts) {
341368
for (unsigned i = 0; i < parts; i++)

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2520,6 +2520,70 @@ TEST(APIntTest, setAllBits) {
25202520
EXPECT_EQ(128u, i128.popcount());
25212521
}
25222522

2523+
// Exercises APInt::clearBits() on single-word and multi-word values, checking
// the trailing/leading bit counts, active bits and population count after
// each call.
TEST(APIntTest, clearBits) {
  // Checks all six bit-count properties of Val in one go.
  auto ExpectCounts = [](const APInt &Val, unsigned TrailingOnes,
                         unsigned TrailingZeros, unsigned ActiveBits,
                         unsigned LeadingZeros, unsigned LeadingOnes,
                         unsigned PopCount) {
    EXPECT_EQ(TrailingOnes, Val.countr_one());
    EXPECT_EQ(TrailingZeros, Val.countr_zero());
    EXPECT_EQ(ActiveBits, Val.getActiveBits());
    EXPECT_EQ(LeadingZeros, Val.countl_zero());
    EXPECT_EQ(LeadingOnes, Val.countl_one());
    EXPECT_EQ(PopCount, Val.popcount());
  };

  // 32-bit value: clear a small range near the bottom.
  APInt Narrow = APInt::getAllOnes(32);
  Narrow.clearBits(1, 3);
  ExpectCounts(Narrow, 1u, 0u, 32u, 0u, 29u, 30u);

  // An empty range must leave the value unchanged.
  Narrow.clearBits(15, 15);
  ExpectCounts(Narrow, 1u, 0u, 32u, 0u, 29u, 30u);

  // Clear a range just below the top bit.
  Narrow.clearBits(28, 31);
  ExpectCounts(Narrow, 1u, 0u, 32u, 0u, 1u, 27u);
  EXPECT_EQ(APInt(32, "8FFFFFF9", 16), Narrow);

  // 256-bit value: range spanning several words.
  APInt Wide256 = APInt::getAllOnes(256);
  Wide256.clearBits(10, 250);
  ExpectCounts(Wide256, 10u, 0u, 256u, 0u, 6u, 16u);

  // 311-bit value: range crossing one word boundary.
  APInt Wide311 = APInt::getAllOnes(311);
  Wide311.clearBits(33, 99);
  ExpectCounts(Wide311, 33u, 0u, 311u, 0u, 212u, 245u);

  // 64-bit value: clear the low half.
  APInt Half64 = APInt::getAllOnes(64);
  Half64.clearBits(0, 32);
  ExpectCounts(Half64, 0u, 32u, 64u, 0u, 32u, 32u);

  // 64-bit value: clear the high half.
  Half64 = APInt::getAllOnes(64);
  Half64.clearBits(32, 64);
  ExpectCounts(Half64, 32u, 0u, 32u, 32u, 0u, 32u);
}
2586+
25232587
TEST(APIntTest, getLoBits) {
25242588
APInt i32(32, 0xfa);
25252589
i32.setHighBits(1);

0 commit comments

Comments
 (0)