Skip to content

Commit fec01ee

Browse files
author
Djordje Todorovic
committed
[AggressiveInstCombine] Lower Table Based CTTZ
This patch introduces recognition of table-based ctz implementation during the AggressiveInstCombine. This fixes the [0]. [0] https://bugs.llvm.org/show_bug.cgi?id=46434 Differential Revision: https://reviews.llvm.org/D113291
1 parent f458d9f commit fec01ee

File tree

7 files changed

+682
-0
lines changed

7 files changed

+682
-0
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,168 @@ foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) {
473473
return false;
474474
}
475475

476+
// Check if this array of constants represents a cttz table.
477+
// Iterate over the elements from \p Table by trying to find/match all
478+
// the numbers from 0 to \p InputBits that should represent cttz results.
479+
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
480+
uint64_t Shift, uint64_t InputBits) {
481+
unsigned Length = Table.getNumElements();
482+
if (Length < InputBits || Length > InputBits * 2)
483+
return false;
484+
485+
APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
486+
unsigned Matched = 0;
487+
488+
for (unsigned i = 0; i < Length; i++) {
489+
uint64_t Element = Table.getElementAsInteger(i);
490+
if (Element >= InputBits)
491+
continue;
492+
493+
// Check if \p Element matches a concrete answer. It could fail for some
494+
// elements that are never accessed, so we keep iterating over each element
495+
// from the table. The number of matched elements should be equal to the
496+
// number of potential right answers which is \p InputBits actually.
497+
if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
498+
Matched++;
499+
}
500+
501+
return Matched == InputBits;
502+
}
503+
504+
// Try to recognize table-based ctz implementation.
505+
// E.g., an example in C (for more cases please see the llvm/tests):
506+
// int f(unsigned x) {
507+
// static const char table[32] =
508+
// {0, 1, 28, 2, 29, 14, 24, 3, 30,
509+
// 22, 20, 15, 25, 17, 4, 8, 31, 27,
510+
// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
511+
// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
512+
// }
513+
// this can be lowered to `cttz` instruction.
514+
// There is also a special case when the element is 0.
515+
//
516+
// Here are some examples or LLVM IR for a 64-bit target:
517+
//
518+
// CASE 1:
519+
// %sub = sub i32 0, %x
520+
// %and = and i32 %sub, %x
521+
// %mul = mul i32 %and, 125613361
522+
// %shr = lshr i32 %mul, 27
523+
// %idxprom = zext i32 %shr to i64
524+
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
525+
// i64 %idxprom %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
526+
//
527+
// CASE 2:
528+
// %sub = sub i32 0, %x
529+
// %and = and i32 %sub, %x
530+
// %mul = mul i32 %and, 72416175
531+
// %shr = lshr i32 %mul, 26
532+
// %idxprom = zext i32 %shr to i64
533+
// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, i64
534+
// 0, i64 %idxprom %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
535+
//
536+
// CASE 3:
537+
// %sub = sub i32 0, %x
538+
// %and = and i32 %sub, %x
539+
// %mul = mul i32 %and, 81224991
540+
// %shr = lshr i32 %mul, 27
541+
// %idxprom = zext i32 %shr to i64
542+
// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, i64
543+
// 0, i64 %idxprom %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
544+
//
545+
// CASE 4:
546+
// %sub = sub i64 0, %x
547+
// %and = and i64 %sub, %x
548+
// %mul = mul i64 %and, 283881067100198605
549+
// %shr = lshr i64 %mul, 58
550+
// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, i64
551+
// %shr %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
552+
//
553+
// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
554+
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
555+
LoadInst *LI = dyn_cast<LoadInst>(&I);
556+
if (!LI)
557+
return false;
558+
559+
Type *AccessType = LI->getType();
560+
if (!AccessType->isIntegerTy())
561+
return false;
562+
563+
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
564+
if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
565+
return false;
566+
567+
if (!GEP->getSourceElementType()->isArrayTy())
568+
return false;
569+
570+
uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
571+
if (ArraySize != 32 && ArraySize != 64)
572+
return false;
573+
574+
GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
575+
if (!GVTable || !GVTable->hasInitializer())
576+
return false;
577+
578+
ConstantDataArray *ConstData =
579+
dyn_cast<ConstantDataArray>(GVTable->getInitializer());
580+
if (!ConstData)
581+
return false;
582+
583+
if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
584+
return false;
585+
586+
Value *Idx2 = std::next(GEP->idx_begin())->get();
587+
Value *X1;
588+
uint64_t MulConst, ShiftConst;
589+
// FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
590+
// probably fail for other (e.g. 32-bit) targets.
591+
if (!match(Idx2, m_ZExtOrSelf(
592+
m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
593+
m_ConstantInt(MulConst)),
594+
m_ConstantInt(ShiftConst)))))
595+
return false;
596+
597+
unsigned InputBits = X1->getType()->getScalarSizeInBits();
598+
if (InputBits != 32 && InputBits != 64)
599+
return false;
600+
601+
// Shift should extract top 5..7 bits.
602+
if (InputBits - Log2_32(InputBits) != ShiftConst &&
603+
InputBits - Log2_32(InputBits) - 1 != ShiftConst)
604+
return false;
605+
606+
if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
607+
return false;
608+
609+
auto ZeroTableElem = ConstData->getElementAsInteger(0);
610+
bool DefinedForZero = ZeroTableElem == InputBits;
611+
612+
IRBuilder<> B(LI);
613+
ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
614+
Type *XType = X1->getType();
615+
auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
616+
Value *ZExtOrTrunc = nullptr;
617+
618+
if (DefinedForZero) {
619+
ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
620+
} else {
621+
// If the value in elem 0 isn't the same as InputBits, we still want to
622+
// produce the value from the table.
623+
auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
624+
auto Select =
625+
B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
626+
627+
// NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
628+
// it should be handled as: `cttz(x) & (typeSize - 1)`.
629+
630+
ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
631+
}
632+
633+
LI->replaceAllUsesWith(ZExtOrTrunc);
634+
635+
return true;
636+
}
637+
476638
/// This is the entry point for folds that could be implemented in regular
477639
/// InstCombine, but they are separated because they are not expected to
478640
/// occur frequently and/or have more than a constant-length pattern match.
@@ -496,6 +658,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
496658
MadeChange |= tryToRecognizePopCount(I);
497659
MadeChange |= tryToFPToSat(I, TTI);
498660
MadeChange |= foldSqrt(I, TTI, TLI);
661+
MadeChange |= tryToRecognizeTableBasedCttz(I);
499662
}
500663
}
501664

0 commit comments

Comments
 (0)