@@ -374,9 +374,27 @@ Pass *llvm::createLoopUnswitchPass(bool Os) {
374
374
return new LoopUnswitch (Os);
375
375
}
376
376
377
+ // / Operator chain lattice.
378
+ enum OperatorChain {
379
+ OC_OpChainNone, // /< There is no operator.
380
+ OC_OpChainOr, // /< There are only ORs.
381
+ OC_OpChainAnd, // /< There are only ANDs.
382
+ OC_OpChainMixed // /< There are ANDs and ORs.
383
+ };
384
+
377
385
// / Cond is a condition that occurs in L. If it is invariant in the loop, or has
378
386
// / an invariant piece, return the invariant. Otherwise, return null.
387
+ //
388
+ // / NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
389
+ // / mixed operator chain, as we can not reliably find a value which will simplify
390
+ // / the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0
391
+ // / to simplify the chain.
392
+ // /
393
+ // / NOTE: In case a partial LIV and a mixed operator chain, we may be able to
394
+ // / simplify the condition itself to a loop variant condition, but at the
395
+ // / cost of creating an entirely new loop.
379
396
static Value *FindLIVLoopCondition (Value *Cond, Loop *L, bool &Changed,
397
+ OperatorChain &ParentChain,
380
398
DenseMap<Value *, Value *> &Cache) {
381
399
auto CacheIt = Cache.find (Cond);
382
400
if (CacheIt != Cache.end ())
@@ -400,31 +418,75 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
400
418
return Cond;
401
419
}
402
420
421
+ // Walk up the operator chain to find partial invariant conditions.
403
422
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
404
423
if (BO->getOpcode () == Instruction::And ||
405
424
BO->getOpcode () == Instruction::Or) {
406
- // If either the left or right side is invariant, we can unswitch on this,
407
- // which will cause the branch to go away in one loop and the condition to
408
- // simplify in the other one.
409
- if (Value *LHS =
410
- FindLIVLoopCondition (BO->getOperand (0 ), L, Changed, Cache)) {
411
- Cache[Cond] = LHS;
412
- return LHS;
425
+ // Given the previous operator, compute the current operator chain status.
426
+ OperatorChain NewChain;
427
+ switch (ParentChain) {
428
+ case OC_OpChainNone:
429
+ NewChain = BO->getOpcode () == Instruction::And ? OC_OpChainAnd :
430
+ OC_OpChainOr;
431
+ break ;
432
+ case OC_OpChainOr:
433
+ NewChain = BO->getOpcode () == Instruction::Or ? OC_OpChainOr :
434
+ OC_OpChainMixed;
435
+ break ;
436
+ case OC_OpChainAnd:
437
+ NewChain = BO->getOpcode () == Instruction::And ? OC_OpChainAnd :
438
+ OC_OpChainMixed;
439
+ break ;
440
+ case OC_OpChainMixed:
441
+ NewChain = OC_OpChainMixed;
442
+ break ;
413
443
}
414
- if (Value *RHS =
415
- FindLIVLoopCondition (BO->getOperand (1 ), L, Changed, Cache)) {
416
- Cache[Cond] = RHS;
417
- return RHS;
444
+
445
+ // If we reach a Mixed state, we do not want to keep walking up as we can not
446
+ // reliably find a value that will simplify the chain. With this check, we
447
+ // will return null on the first sight of mixed chain and the caller will
448
+ // either backtrack to find partial LIV in other operand or return null.
449
+ if (NewChain != OC_OpChainMixed) {
450
+ // Update the current operator chain type before we search up the chain.
451
+ ParentChain = NewChain;
452
+ // If either the left or right side is invariant, we can unswitch on this,
453
+ // which will cause the branch to go away in one loop and the condition to
454
+ // simplify in the other one.
455
+ if (Value *LHS = FindLIVLoopCondition (BO->getOperand (0 ), L, Changed,
456
+ ParentChain, Cache)) {
457
+ Cache[Cond] = LHS;
458
+ return LHS;
459
+ }
460
+ // We did not manage to find a partial LIV in operand(0). Backtrack and try
461
+ // operand(1).
462
+ ParentChain = NewChain;
463
+ if (Value *RHS = FindLIVLoopCondition (BO->getOperand (1 ), L, Changed,
464
+ ParentChain, Cache)) {
465
+ Cache[Cond] = RHS;
466
+ return RHS;
467
+ }
418
468
}
419
469
}
420
470
421
471
Cache[Cond] = nullptr ;
422
472
return nullptr ;
423
473
}
424
474
425
- static Value *FindLIVLoopCondition (Value *Cond, Loop *L, bool &Changed) {
475
+ // / Cond is a condition that occurs in L. If it is invariant in the loop, or has
476
+ // / an invariant piece, return the invariant along with the operator chain type.
477
+ // / Otherwise, return null.
478
+ static std::pair<Value *, OperatorChain> FindLIVLoopCondition (Value *Cond,
479
+ Loop *L,
480
+ bool &Changed) {
426
481
DenseMap<Value *, Value *> Cache;
427
- return FindLIVLoopCondition (Cond, L, Changed, Cache);
482
+ OperatorChain OpChain = OC_OpChainNone;
483
+ Value *FCond = FindLIVLoopCondition (Cond, L, Changed, OpChain, Cache);
484
+
485
+ // In case we do find a LIV, it can not be obtained by walking up a mixed
486
+ // operator chain.
487
+ assert ((!FCond || OpChain != OC_OpChainMixed) &&
488
+ " Do not expect a partial LIV with mixed operator chain" );
489
+ return {FCond, OpChain};
428
490
}
429
491
430
492
bool LoopUnswitch::runOnLoop (Loop *L, LPPassManager &LPM_Ref) {
@@ -556,7 +618,7 @@ bool LoopUnswitch::processCurrentLoop() {
556
618
557
619
for (IntrinsicInst *Guard : Guards) {
558
620
Value *LoopCond =
559
- FindLIVLoopCondition (Guard->getOperand (0 ), currentLoop, Changed);
621
+ FindLIVLoopCondition (Guard->getOperand (0 ), currentLoop, Changed). first ;
560
622
if (LoopCond &&
561
623
UnswitchIfProfitable (LoopCond, ConstantInt::getTrue (Context))) {
562
624
// NB! Unswitching (if successful) could have erased some of the
@@ -597,32 +659,57 @@ bool LoopUnswitch::processCurrentLoop() {
597
659
// See if this, or some part of it, is loop invariant. If so, we can
598
660
// unswitch on it if we desire.
599
661
Value *LoopCond = FindLIVLoopCondition (BI->getCondition (),
600
- currentLoop, Changed);
662
+ currentLoop, Changed). first ;
601
663
if (LoopCond &&
602
664
UnswitchIfProfitable (LoopCond, ConstantInt::getTrue (Context), TI)) {
603
665
++NumBranches;
604
666
return true ;
605
667
}
606
668
}
607
669
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
608
- Value *LoopCond = FindLIVLoopCondition (SI->getCondition (),
609
- currentLoop, Changed);
670
+ Value *SC = SI->getCondition ();
671
+ Value *LoopCond;
672
+ OperatorChain OpChain;
673
+ std::tie (LoopCond, OpChain) =
674
+ FindLIVLoopCondition (SC, currentLoop, Changed);
675
+
610
676
unsigned NumCases = SI->getNumCases ();
611
677
if (LoopCond && NumCases) {
612
678
// Find a value to unswitch on:
613
679
// FIXME: this should chose the most expensive case!
614
680
// FIXME: scan for a case with a non-critical edge?
615
681
Constant *UnswitchVal = nullptr ;
616
-
617
- // Do not process same value again and again.
618
- // At this point we have some cases already unswitched and
619
- // some not yet unswitched. Let's find the first not yet unswitched one.
620
- for (SwitchInst::CaseIt i = SI->case_begin (), e = SI->case_end ();
621
- i != e; ++i) {
622
- Constant *UnswitchValCandidate = i.getCaseValue ();
623
- if (!BranchesInfo.isUnswitched (SI, UnswitchValCandidate)) {
624
- UnswitchVal = UnswitchValCandidate;
625
- break ;
682
+ // Find a case value such that at least one case value is unswitched
683
+ // out.
684
+ if (OpChain == OC_OpChainAnd) {
685
+ // If the chain only has ANDs and the switch has a case value of 0.
686
+ // Dropping in a 0 to the chain will unswitch out the 0-casevalue.
687
+ auto *AllZero = cast<ConstantInt>(Constant::getNullValue (SC->getType ()));
688
+ if (BranchesInfo.isUnswitched (SI, AllZero))
689
+ continue ;
690
+ // We are unswitching 0 out.
691
+ UnswitchVal = AllZero;
692
+ } else if (OpChain == OC_OpChainOr) {
693
+ // If the chain only has ORs and the switch has a case value of ~0.
694
+ // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue.
695
+ auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue (SC->getType ()));
696
+ if (BranchesInfo.isUnswitched (SI, AllOne))
697
+ continue ;
698
+ // We are unswitching ~0 out.
699
+ UnswitchVal = AllOne;
700
+ } else {
701
+ assert (OpChain == OC_OpChainNone &&
702
+ " Expect to unswitch on trivial chain" );
703
+ // Do not process same value again and again.
704
+ // At this point we have some cases already unswitched and
705
+ // some not yet unswitched. Let's find the first not yet unswitched one.
706
+ for (SwitchInst::CaseIt i = SI->case_begin (), e = SI->case_end ();
707
+ i != e; ++i) {
708
+ Constant *UnswitchValCandidate = i.getCaseValue ();
709
+ if (!BranchesInfo.isUnswitched (SI, UnswitchValCandidate)) {
710
+ UnswitchVal = UnswitchValCandidate;
711
+ break ;
712
+ }
626
713
}
627
714
}
628
715
@@ -631,6 +718,11 @@ bool LoopUnswitch::processCurrentLoop() {
631
718
632
719
if (UnswitchIfProfitable (LoopCond, UnswitchVal)) {
633
720
++NumSwitches;
721
+ // In case of a full LIV, UnswitchVal is the value we unswitched out.
722
+ // In case of a partial LIV, we only unswitch when its an AND-chain
723
+ // or OR-chain. In both cases switch input value simplifies to
724
+ // UnswitchVal.
725
+ BranchesInfo.setUnswitched (SI, UnswitchVal);
634
726
return true ;
635
727
}
636
728
}
@@ -641,7 +733,7 @@ bool LoopUnswitch::processCurrentLoop() {
641
733
BBI != E; ++BBI)
642
734
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
643
735
Value *LoopCond = FindLIVLoopCondition (SI->getCondition (),
644
- currentLoop, Changed);
736
+ currentLoop, Changed). first ;
645
737
if (LoopCond && UnswitchIfProfitable (LoopCond,
646
738
ConstantInt::getTrue (Context))) {
647
739
++NumSelects;
@@ -900,7 +992,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
900
992
return false ;
901
993
902
994
Value *LoopCond = FindLIVLoopCondition (BI->getCondition (),
903
- currentLoop, Changed);
995
+ currentLoop, Changed). first ;
904
996
905
997
// Unswitch only if the trivial condition itself is an LIV (not
906
998
// partial LIV which could occur in and/or)
@@ -931,7 +1023,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
931
1023
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
932
1024
// If this isn't switching on an invariant condition, we can't unswitch it.
933
1025
Value *LoopCond = FindLIVLoopCondition (SI->getCondition (),
934
- currentLoop, Changed);
1026
+ currentLoop, Changed). first ;
935
1027
936
1028
// Unswitch only if the trivial condition itself is an LIV (not
937
1029
// partial LIV which could occur in and/or)
@@ -969,6 +1061,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
969
1061
970
1062
UnswitchTrivialCondition (currentLoop, LoopCond, CondVal, LoopExitBB,
971
1063
nullptr );
1064
+
1065
+ // We are only unswitching full LIV.
1066
+ BranchesInfo.setUnswitched (SI, CondVal);
972
1067
++NumSwitches;
973
1068
return true ;
974
1069
}
@@ -1250,6 +1345,9 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
1250
1345
SwitchInst *SI = dyn_cast<SwitchInst>(UI);
1251
1346
if (!SI || !isa<ConstantInt>(Val)) continue ;
1252
1347
1348
+ // NOTE: if a case value for the switch is unswitched out, we record it
1349
+ // after the unswitch finishes. We can not record it here as the switch
1350
+ // is not a direct user of the partial LIV.
1253
1351
SwitchInst::CaseIt DeadCase = SI->findCaseValue (cast<ConstantInt>(Val));
1254
1352
// Default case is live for multiple values.
1255
1353
if (DeadCase == SI->case_default ()) continue ;
@@ -1262,8 +1360,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
1262
1360
BasicBlock *SISucc = DeadCase.getCaseSuccessor ();
1263
1361
BasicBlock *Latch = L->getLoopLatch ();
1264
1362
1265
- BranchesInfo.setUnswitched (SI, Val);
1266
-
1267
1363
if (!SI->findCaseDest (SISucc)) continue ; // Edge is critical.
1268
1364
// If the DeadCase successor dominates the loop latch, then the
1269
1365
// transformation isn't safe since it will delete the sole predecessor edge
0 commit comments