70
70
#include " llvm/Transforms/Scalar/LoopPassManager.h"
71
71
#include " llvm/Transforms/Utils/Local.h"
72
72
#include " llvm/Transforms/Utils/LoopUtils.h"
73
+ #include " llvm/Transforms/Utils/LoopVersioning.h"
73
74
#include " llvm/Transforms/Utils/ScalarEvolutionExpander.h"
74
75
#include " llvm/Transforms/Utils/SimplifyIndVar.h"
75
76
#include < optional>
@@ -97,6 +98,10 @@ static cl::opt<bool>
97
98
cl::desc(" Widen the loop induction variables, if possible, so "
98
99
" overflow checks won't reject flattening" ));
99
100
101
+ static cl::opt<bool >
102
+ VersionLoops (" loop-flatten-version-loops" , cl::Hidden, cl::init(true ),
103
+ cl::desc(" Version loops if flattened loop could overflow" ));
104
+
100
105
namespace {
101
106
// We require all uses of both induction variables to match this pattern:
102
107
//
@@ -141,6 +146,8 @@ struct FlattenInfo {
141
146
// has been applied. Used to skip
142
147
// checks on phi nodes.
143
148
149
+ Value *NewTripCount = nullptr ; // The tripcount of the flattened loop.
150
+
144
151
FlattenInfo (Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
145
152
146
153
bool isNarrowInductionPhi (PHINode *Phi) {
@@ -752,11 +759,13 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
752
759
ORE.emit (Remark);
753
760
}
754
761
755
- Value *NewTripCount = BinaryOperator::CreateMul (
756
- FI.InnerTripCount , FI.OuterTripCount , " flatten.tripcount" ,
757
- FI.OuterLoop ->getLoopPreheader ()->getTerminator ());
758
- LLVM_DEBUG (dbgs () << " Created new trip count in preheader: " ;
759
- NewTripCount->dump ());
762
+ if (!FI.NewTripCount ) {
763
+ FI.NewTripCount = BinaryOperator::CreateMul (
764
+ FI.InnerTripCount , FI.OuterTripCount , " flatten.tripcount" ,
765
+ FI.OuterLoop ->getLoopPreheader ()->getTerminator ());
766
+ LLVM_DEBUG (dbgs () << " Created new trip count in preheader: " ;
767
+ FI.NewTripCount ->dump ());
768
+ }
760
769
761
770
// Fix up PHI nodes that take values from the inner loop back-edge, which
762
771
// we are about to remove.
@@ -769,7 +778,7 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
769
778
770
779
// Modify the trip count of the outer loop to be the product of the two
771
780
// trip counts.
772
- cast<User>(FI.OuterBranch ->getCondition ())->setOperand (1 , NewTripCount);
781
+ cast<User>(FI.OuterBranch ->getCondition ())->setOperand (1 , FI. NewTripCount );
773
782
774
783
// Replace the inner loop backedge with an unconditional branch to the exit.
775
784
BasicBlock *InnerExitBlock = FI.InnerLoop ->getExitBlock ();
@@ -891,7 +900,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
891
900
static bool FlattenLoopPair (FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
892
901
ScalarEvolution *SE, AssumptionCache *AC,
893
902
const TargetTransformInfo *TTI, LPMUpdater *U,
894
- MemorySSAUpdater *MSSAU) {
903
+ MemorySSAUpdater *MSSAU,
904
+ const LoopAccessInfo &LAI) {
895
905
LLVM_DEBUG (
896
906
dbgs () << " Loop flattening running on outer loop "
897
907
<< FI.OuterLoop ->getHeader ()->getName () << " and inner loop "
@@ -926,18 +936,55 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
926
936
// variable might overflow. In this case, we need to version the loop, and
927
937
// select the original version at runtime if the iteration space is too
928
938
// large.
929
- // TODO: We currently don't version the loop.
930
939
OverflowResult OR = checkOverflow (FI, DT, AC);
931
940
if (OR == OverflowResult::AlwaysOverflowsHigh ||
932
941
OR == OverflowResult::AlwaysOverflowsLow) {
933
942
LLVM_DEBUG (dbgs () << " Multiply would always overflow, so not profitable\n " );
934
943
return false ;
935
944
} else if (OR == OverflowResult::MayOverflow) {
936
- LLVM_DEBUG (dbgs () << " Multiply might overflow, not flattening\n " );
937
- return false ;
945
+ Module *M = FI.OuterLoop ->getHeader ()->getParent ()->getParent ();
946
+ const DataLayout &DL = M->getDataLayout ();
947
+ if (!VersionLoops) {
948
+ LLVM_DEBUG (dbgs () << " Multiply might overflow, not flattening\n " );
949
+ return false ;
950
+ } else if (!DL.isLegalInteger (
951
+ FI.OuterTripCount ->getType ()->getScalarSizeInBits ())) {
952
+ // If the trip count type isn't legal then it won't be possible to check
953
+ // for overflow using only a single multiply instruction, so don't
954
+ // flatten.
955
+ LLVM_DEBUG (
956
+ dbgs () << " Can't check overflow efficiently, not flattening\n " );
957
+ return false ;
958
+ }
959
+ LLVM_DEBUG (dbgs () << " Multiply might overflow, versioning loop\n " );
960
+
961
+ // Version the loop. The overflow check isn't a runtime pointer check, so we
962
+ // pass an empty list of runtime pointer checks, causing LoopVersioning to
963
+ // emit 'false' as the branch condition, and add our own check afterwards.
964
+ BasicBlock *CheckBlock = FI.OuterLoop ->getLoopPreheader ();
965
+ ArrayRef<RuntimePointerCheck> Checks (nullptr , nullptr );
966
+ LoopVersioning LVer (LAI, Checks, FI.OuterLoop , LI, DT, SE);
967
+ LVer.versionLoop ();
968
+
969
+ // Check for overflow by calculating the new tripcount using
970
+ // umul_with_overflow and then checking if it overflowed.
971
+ BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator ());
972
+ assert (Br->isConditional () &&
973
+ " Expected LoopVersioning to generate a conditional branch" );
974
+ assert (match (Br->getCondition (), m_Zero ()) &&
975
+ " Expected branch condition to be false" );
976
+ IRBuilder<> Builder (Br);
977
+ Function *F = Intrinsic::getDeclaration (M, Intrinsic::umul_with_overflow,
978
+ FI.OuterTripCount ->getType ());
979
+ Value *Call = Builder.CreateCall (F, {FI.OuterTripCount , FI.InnerTripCount },
980
+ " flatten.mul" );
981
+ FI.NewTripCount = Builder.CreateExtractValue (Call, 0 , " flatten.tripcount" );
982
+ Value *Overflow = Builder.CreateExtractValue (Call, 1 , " flatten.overflow" );
983
+ Br->setCondition (Overflow);
984
+ } else {
985
+ LLVM_DEBUG (dbgs () << " Multiply cannot overflow, modifying loop in-place\n " );
938
986
}
939
987
940
- LLVM_DEBUG (dbgs () << " Multiply cannot overflow, modifying loop in-place\n " );
941
988
return DoFlattenLoopPair (FI, DT, LI, SE, AC, TTI, U, MSSAU);
942
989
}
943
990
@@ -958,13 +1005,15 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
958
1005
// in simplified form, and also needs LCSSA. Running
959
1006
// this pass will simplify all loops that contain inner loops,
960
1007
// regardless of whether anything ends up being flattened.
1008
+ LoopAccessInfoManager LAIM (AR.SE , AR.AA , AR.DT , AR.LI , nullptr );
961
1009
for (Loop *InnerLoop : LN.getLoops ()) {
962
1010
auto *OuterLoop = InnerLoop->getParentLoop ();
963
1011
if (!OuterLoop)
964
1012
continue ;
965
1013
FlattenInfo FI (OuterLoop, InnerLoop);
966
- Changed |= FlattenLoopPair (FI, &AR.DT , &AR.LI , &AR.SE , &AR.AC , &AR.TTI , &U,
967
- MSSAU ? &*MSSAU : nullptr );
1014
+ Changed |=
1015
+ FlattenLoopPair (FI, &AR.DT , &AR.LI , &AR.SE , &AR.AC , &AR.TTI , &U,
1016
+ MSSAU ? &*MSSAU : nullptr , LAIM.getInfo (*OuterLoop));
968
1017
}
969
1018
970
1019
if (!Changed)
0 commit comments