@@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
7131
7131
return true ;
7132
7132
}
7133
7133
7134
+ // / Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
7135
+ // / the same destination.
7136
+ static bool simplifySwitchOfCmpIntrinsic (SwitchInst *SI, IRBuilderBase &Builder,
7137
+ DomTreeUpdater *DTU) {
7138
+ auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition ());
7139
+ if (!Cmp || !Cmp->hasOneUse ())
7140
+ return false ;
7141
+
7142
+ SmallVector<uint32_t , 4 > Weights;
7143
+ bool HasWeights = extractBranchWeights (getBranchWeightMDNode (*SI), Weights);
7144
+ if (!HasWeights)
7145
+ Weights.resize (4 ); // Avoid checking HasWeights everywhere.
7146
+
7147
+ // Normalize to [us]cmp == Res ? Succ : OtherSucc.
7148
+ int64_t Res;
7149
+ BasicBlock *Succ, *OtherSucc;
7150
+ uint32_t SuccWeight = 0 , OtherSuccWeight = 0 ;
7151
+ BasicBlock *Unreachable = nullptr ;
7152
+
7153
+ if (SI->getNumCases () == 2 ) {
7154
+ // Find which of 1, 0 or -1 is missing (handled by default dest).
7155
+ SmallSet<int64_t , 3 > Missing;
7156
+ Missing.insert (1 );
7157
+ Missing.insert (0 );
7158
+ Missing.insert (-1 );
7159
+
7160
+ Succ = SI->getDefaultDest ();
7161
+ SuccWeight = Weights[0 ];
7162
+ OtherSucc = nullptr ;
7163
+ for (auto &Case : SI->cases ()) {
7164
+ std::optional<int64_t > Val =
7165
+ Case.getCaseValue ()->getValue ().trySExtValue ();
7166
+ if (!Val)
7167
+ return false ;
7168
+ if (!Missing.erase (*Val))
7169
+ return false ;
7170
+ if (OtherSucc && OtherSucc != Case.getCaseSuccessor ())
7171
+ return false ;
7172
+ OtherSucc = Case.getCaseSuccessor ();
7173
+ OtherSuccWeight += Weights[Case.getSuccessorIndex ()];
7174
+ }
7175
+
7176
+ assert (Missing.size () == 1 && " Should have one case left" );
7177
+ Res = *Missing.begin ();
7178
+ } else if (SI->getNumCases () == 3 && SI->defaultDestUndefined ()) {
7179
+ // Normalize so that Succ is taken once and OtherSucc twice.
7180
+ Unreachable = SI->getDefaultDest ();
7181
+ Succ = OtherSucc = nullptr ;
7182
+ for (auto &Case : SI->cases ()) {
7183
+ BasicBlock *NewSucc = Case.getCaseSuccessor ();
7184
+ uint32_t Weight = Weights[Case.getSuccessorIndex ()];
7185
+ if (!OtherSucc || OtherSucc == NewSucc) {
7186
+ OtherSucc = NewSucc;
7187
+ OtherSuccWeight += Weight;
7188
+ } else if (!Succ) {
7189
+ Succ = NewSucc;
7190
+ SuccWeight = Weight;
7191
+ } else if (Succ == NewSucc) {
7192
+ std::swap (Succ, OtherSucc);
7193
+ std::swap (SuccWeight, OtherSuccWeight);
7194
+ } else
7195
+ return false ;
7196
+ }
7197
+ for (auto &Case : SI->cases ()) {
7198
+ std::optional<int64_t > Val =
7199
+ Case.getCaseValue ()->getValue ().trySExtValue ();
7200
+ if (!Val || (Val != 1 && Val != 0 && Val != -1 ))
7201
+ return false ;
7202
+ if (Case.getCaseSuccessor () == Succ) {
7203
+ Res = *Val;
7204
+ break ;
7205
+ }
7206
+ }
7207
+ } else {
7208
+ return false ;
7209
+ }
7210
+
7211
+ // Determine predicate for the missing case.
7212
+ ICmpInst::Predicate Pred;
7213
+ switch (Res) {
7214
+ case 1 :
7215
+ Pred = ICmpInst::ICMP_UGT;
7216
+ break ;
7217
+ case 0 :
7218
+ Pred = ICmpInst::ICMP_EQ;
7219
+ break ;
7220
+ case -1 :
7221
+ Pred = ICmpInst::ICMP_ULT;
7222
+ break ;
7223
+ }
7224
+ if (Cmp->isSigned ())
7225
+ Pred = ICmpInst::getSignedPredicate (Pred);
7226
+
7227
+ MDNode *NewWeights = nullptr ;
7228
+ if (HasWeights)
7229
+ NewWeights = MDBuilder (SI->getContext ())
7230
+ .createBranchWeights (SuccWeight, OtherSuccWeight);
7231
+
7232
+ BasicBlock *BB = SI->getParent ();
7233
+ Builder.SetInsertPoint (SI->getIterator ());
7234
+ Value *ICmp = Builder.CreateICmp (Pred, Cmp->getLHS (), Cmp->getRHS ());
7235
+ Builder.CreateCondBr (ICmp, Succ, OtherSucc, NewWeights,
7236
+ SI->getMetadata (LLVMContext::MD_unpredictable));
7237
+ OtherSucc->removePredecessor (BB);
7238
+ if (Unreachable)
7239
+ Unreachable->removePredecessor (BB);
7240
+ SI->eraseFromParent ();
7241
+ Cmp->eraseFromParent ();
7242
+ if (DTU && Unreachable)
7243
+ DTU->applyUpdates ({{DominatorTree::Delete, BB, Unreachable}});
7244
+ return true ;
7245
+ }
7246
+
7134
7247
bool SimplifyCFGOpt::simplifySwitch (SwitchInst *SI, IRBuilder<> &Builder) {
7135
7248
BasicBlock *BB = SI->getParent ();
7136
7249
@@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
7163
7276
if (eliminateDeadSwitchCases (SI, DTU, Options.AC , DL))
7164
7277
return requestResimplify ();
7165
7278
7279
+ if (simplifySwitchOfCmpIntrinsic (SI, Builder, DTU))
7280
+ return requestResimplify ();
7281
+
7166
7282
if (trySwitchToSelect (SI, Builder, DTU, DL, TTI))
7167
7283
return requestResimplify ();
7168
7284
0 commit comments