[SimplifyCFG] Add optimization for switches of powers of two #70977

Merged
merged 1 commit into from
Nov 18, 2023
80 changes: 77 additions & 3 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6792,9 +6792,6 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,

// This transform can be done speculatively because it is so cheap - it
// results in a single rotate operation being inserted.
- // FIXME: It's possible that optimizing a switch on powers of two might also
- // be beneficial - flag values are often powers of two and we could use a CLZ
- // as the key function.

// countTrailingZeros(0) returns 64. As Values is guaranteed to have more than
// one element and LLVM disallows duplicate cases, Shift is guaranteed to be
@@ -6839,6 +6836,80 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}

/// Tries to transform switch of powers of two to reduce switch range.
/// For example, switch like:
/// switch (C) { case 1: case 2: case 64: case 128: }
/// will be transformed to:
/// switch (count_trailing_zeros(C)) { case 0: case 1: case 6: case 7: }
///
/// This transformation allows better lowering and could allow transforming into
/// a lookup table.
static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
const DataLayout &DL,
const TargetTransformInfo &TTI) {
Value *Condition = SI->getCondition();
LLVMContext &Context = SI->getContext();
auto *CondTy = cast<IntegerType>(Condition->getType());

if (CondTy->getIntegerBitWidth() > 64 ||
!DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
return false;

const auto CttzIntrinsicCost = TTI.getIntrinsicInstrCost(
IntrinsicCostAttributes(Intrinsic::cttz, CondTy,
{Condition, ConstantInt::getTrue(Context)}),
TTI::TCK_SizeAndLatency);

if (CttzIntrinsicCost > TTI::TCC_Basic)
// Inserting intrinsic is too expensive.
return false;

// Only bother with this optimization if there are more than 3 switch cases.
// SDAG will only bother creating jump tables for 4 or more cases.
if (SI->getNumCases() < 4)
return false;

Review comment (Contributor): Worth noting that AArch64 started using a larger limit very recently (#71166). But I think they can add a TTI hook for this if it becomes a problem for them.

Reply (Member): I will post a patch later.

// We perform this optimization only for switches with an
// unreachable default case.
// This assumption saves us from checking whether `Condition` is a power of two.
if (!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()))
return false;

// Check that switch cases are powers of two.
SmallVector<uint64_t, 4> Values;
for (const auto &Case : SI->cases()) {
uint64_t CaseValue = Case.getCaseValue()->getValue().getZExtValue();
if (llvm::has_single_bit(CaseValue))
Values.push_back(CaseValue);
else
return false;
}

// isSwitchDense requires case values to be sorted.
llvm::sort(Values);
if (!isSwitchDense(Values.size(), llvm::countr_zero(Values.back()) -
llvm::countr_zero(Values.front()) + 1))
// Transform is unable to generate dense switch.
return false;

Builder.SetInsertPoint(SI);

// Replace each case with its trailing zeros number.
for (auto &Case : SI->cases()) {
auto *OrigValue = Case.getCaseValue();
Case.setValue(ConstantInt::get(OrigValue->getType(),
OrigValue->getValue().countr_zero()));
}

// Replace condition with its trailing zeros number.
auto *ConditionTrailingZeros = Builder.CreateIntrinsic(
Intrinsic::cttz, {CondTy}, {Condition, ConstantInt::getTrue(Context)});

SI->setCondition(ConditionTrailingZeros);

return true;
}

bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
BasicBlock *BB = SI->getParent();

@@ -6886,6 +6957,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
SwitchToLookupTable(SI, Builder, DTU, DL, TTI))
return requestResimplify();

if (simplifySwitchOfPowersOfTwo(SI, Builder, DL, TTI))
return requestResimplify();

if (ReduceSwitchRange(SI, Builder, DL, TTI))
return requestResimplify();

2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/switch-unreachable-default.ll
@@ -71,7 +71,7 @@ entry:
i32 8, label %bb2
i32 16, label %bb3
i32 32, label %bb4
- i32 64, label %bb5
+ i32 -64, label %bb5
]

; The switch is lowered with a jump table for cases 1--32 and case 64 handled