@@ -6453,6 +6453,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
6453
6453
// a predicated block since it will become a fall-through, although we
6454
6454
// may decide in the future to call TTI for all branches.
6455
6455
}
6456
+ case Instruction::Switch: {
6457
+ if (VF.isScalar ())
6458
+ return TTI.getCFInstrCost (Instruction::Switch, CostKind);
6459
+ auto *Switch = cast<SwitchInst>(I);
6460
+ return Switch->getNumCases () *
6461
+ TTI.getCmpSelInstrCost (
6462
+ Instruction::ICmp,
6463
+ ToVectorTy (Switch->getCondition ()->getType (), VF),
6464
+ ToVectorTy (Type::getInt1Ty (I->getContext ()), VF),
6465
+ CmpInst::ICMP_EQ, CostKind);
6466
+ }
6456
6467
case Instruction::PHI: {
6457
6468
auto *Phi = cast<PHINode>(I);
6458
6469
@@ -7841,6 +7852,62 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
7841
7852
return map_range (Operands, Fn);
7842
7853
}
7843
7854
7855
+ void VPRecipeBuilder::createSwitchEdgeMasks (SwitchInst *SI) {
7856
+ BasicBlock *Src = SI->getParent ();
7857
+ assert (!OrigLoop->isLoopExiting (Src) &&
7858
+ all_of (successors (Src),
7859
+ [this ](BasicBlock *Succ) {
7860
+ return OrigLoop->getHeader () != Succ;
7861
+ }) &&
7862
+ " unsupported switch either exiting loop or continuing to header" );
7863
+ // Create masks where the terminator in Src is a switch. We create mask for
7864
+ // all edges at the same time. This is more efficient, as we can create and
7865
+ // collect compares for all cases once.
7866
+ VPValue *Cond = getVPValueOrAddLiveIn (SI->getCondition (), Plan);
7867
+ BasicBlock *DefaultDst = SI->getDefaultDest ();
7868
+ MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
7869
+ for (auto &C : SI->cases ()) {
7870
+ BasicBlock *Dst = C.getCaseSuccessor ();
7871
+ assert (!EdgeMaskCache.contains ({Src, Dst}) && " Edge masks already created" );
7872
+ // Cases whose destination is the same as default are redundant and can be
7873
+ // ignored - they will get there anyhow.
7874
+ if (Dst == DefaultDst)
7875
+ continue ;
7876
+ auto I = Dst2Compares.insert ({Dst, {}});
7877
+ VPValue *V = getVPValueOrAddLiveIn (C.getCaseValue (), Plan);
7878
+ I.first ->second .push_back (Builder.createICmp (CmpInst::ICMP_EQ, Cond, V));
7879
+ }
7880
+
7881
+ // We need to handle 2 separate cases below for all entries in Dst2Compares,
7882
+ // which excludes destinations matching the default destination.
7883
+ VPValue *SrcMask = getBlockInMask (Src);
7884
+ VPValue *DefaultMask = nullptr ;
7885
+ for (const auto &[Dst, Conds] : Dst2Compares) {
7886
+ // 1. Dst is not the default destination. Dst is reached if any of the cases
7887
+ // with destination == Dst are taken. Join the conditions for each case
7888
+ // whose destination == Dst using an OR.
7889
+ VPValue *Mask = Conds[0 ];
7890
+ for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front ())
7891
+ Mask = Builder.createOr (Mask, V);
7892
+ if (SrcMask)
7893
+ Mask = Builder.createLogicalAnd (SrcMask, Mask);
7894
+ EdgeMaskCache[{Src, Dst}] = Mask;
7895
+
7896
+ // 2. Create the mask for the default destination, which is reached if none
7897
+ // of the cases with destination != default destination are taken. Join the
7898
+ // conditions for each case where the destination is != Dst using an OR and
7899
+ // negate it.
7900
+ DefaultMask = DefaultMask ? Builder.createOr (DefaultMask, Mask) : Mask;
7901
+ }
7902
+
7903
+ if (DefaultMask) {
7904
+ DefaultMask = Builder.createNot (DefaultMask);
7905
+ if (SrcMask)
7906
+ DefaultMask = Builder.createLogicalAnd (SrcMask, DefaultMask);
7907
+ }
7908
+ EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
7909
+ }
7910
+
7844
7911
VPValue *VPRecipeBuilder::createEdgeMask (BasicBlock *Src, BasicBlock *Dst) {
7845
7912
assert (is_contained (predecessors (Dst), Src) && " Invalid edge" );
7846
7913
@@ -7850,12 +7917,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
7850
7917
if (ECEntryIt != EdgeMaskCache.end ())
7851
7918
return ECEntryIt->second ;
7852
7919
7920
+ if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator ())) {
7921
+ createSwitchEdgeMasks (SI);
7922
+ assert (EdgeMaskCache.contains (Edge) && " Mask for Edge not created?" );
7923
+ return EdgeMaskCache[Edge];
7924
+ }
7925
+
7853
7926
VPValue *SrcMask = getBlockInMask (Src);
7854
7927
7855
7928
// The terminator has to be a branch inst!
7856
7929
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator ());
7857
7930
assert (BI && " Unexpected terminator found" );
7858
-
7859
7931
if (!BI->isConditional () || BI->getSuccessor (0 ) == BI->getSuccessor (1 ))
7860
7932
return EdgeMaskCache[Edge] = SrcMask;
7861
7933
0 commit comments