Skip to content

Commit c102c78

Browse files
committed
[OpenMPIRBuilder] introduce createStaticWorkshareLoop
Introduce a function that creates a statically-scheduled workshare loop out of a canonical loop created earlier by the OpenMPIRBuilder. This basically amounts to injecting runtime calls to the preheader and the after block and updating the trip count. Static scheduling kind is currently hardcoded and needs to be extracted from the runtime library into common TableGen definitions. Differential Revision: https://reviews.llvm.org/D92476
1 parent 6249bfe commit c102c78

File tree

3 files changed

+231
-3
lines changed

3 files changed

+231
-3
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,32 @@ class OpenMPIRBuilder {
260260
Value *Start, Value *Stop, Value *Step,
261261
bool IsSigned, bool InclusiveStop);
262262

263+
/// Modifies the canonical loop to be a statically-scheduled workshare loop.
264+
///
265+
/// This takes a \p LoopInfo representing a canonical loop, such as the one
266+
/// created by \p createCanonicalLoop and emits additional instructions to
267+
/// turn it into a workshare loop. In particular, it calls an OpenMP
268+
/// runtime function in the preheader to obtain the loop bounds to be used in
269+
/// the current thread, updates the relevant instructions in the canonical
270+
/// loop and calls an OpenMP runtime finalization function after the loop.
271+
///
272+
/// \param Loc The source location description, the insertion location
273+
/// is not used.
274+
/// \param CLI A descriptor of the canonical loop to workshare.
275+
/// \param AllocaIP An insertion point for Alloca instructions usable in the
276+
/// preheader of the loop.
277+
/// \param NeedsBarrier Indicates whether a barrier must be inserted after
278+
/// the loop.
279+
/// \param Chunk The size of loop chunk considered as a unit when
280+
/// scheduling. If \p nullptr, defaults to 1.
281+
///
282+
/// \returns Updated CanonicalLoopInfo.
283+
CanonicalLoopInfo *createStaticWorkshareLoop(const LocationDescription &Loc,
284+
CanonicalLoopInfo *CLI,
285+
InsertPointTy AllocaIP,
286+
bool NeedsBarrier,
287+
Value *Chunk = nullptr);
288+
263289
/// Generator for '#omp flush'
264290
///
265291
/// \param Loc The location where the flush directive was encountered
@@ -636,15 +662,19 @@ class OpenMPIRBuilder {
636662
/// | Cond---\
637663
/// | | |
638664
/// | Body |
639-
/// | | |
665+
/// | | | |
666+
/// | <...> |
667+
/// | | | |
640668
/// \--Latch |
641669
/// |
642670
/// Exit
643671
/// |
644672
/// After
645673
///
646674
/// Code in the header, condition block, latch and exit block must not have any
647-
/// side-effect.
675+
/// side-effect. The body block is the single entry point into the loop body,
676+
/// which may contain arbitrary control flow as long as all control paths
677+
/// eventually branch to the latch block.
648678
///
649679
/// Defined outside OpenMPIRBuilder because one cannot forward-declare nested
650680
/// classes.
@@ -701,7 +731,7 @@ class CanonicalLoopInfo {
701731
/// statements/cancellations).
702732
BasicBlock *getAfter() const { return After; }
703733

704-
/// Returns the llvm::Value containing the number of loop iterations. I must
734+
/// Returns the llvm::Value containing the number of loop iterations. It must
705735
/// be valid in the preheader and always interpreted as an unsigned integer of
706736
/// any bit-width.
707737
Value *getTripCount() const {

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,118 @@ CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
999999
return createCanonicalLoop(Builder.saveIP(), BodyGen, TripCount);
10001000
}
10011001

1002+
// Returns an LLVM function to call for initializing loop bounds using OpenMP
1003+
// static scheduling depending on `type`. Only i32 and i64 are supported by the
1004+
// runtime. Always interpret integers as unsigned similarly to
1005+
// CanonicalLoopInfo.
1006+
static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
1007+
OpenMPIRBuilder &OMPBuilder) {
1008+
unsigned Bitwidth = Ty->getIntegerBitWidth();
1009+
if (Bitwidth == 32)
1010+
return OMPBuilder.getOrCreateRuntimeFunction(
1011+
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
1012+
if (Bitwidth == 64)
1013+
return OMPBuilder.getOrCreateRuntimeFunction(
1014+
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
1015+
llvm_unreachable("unknown OpenMP loop iterator bitwidth");
1016+
}
1017+
1018+
// Sets the number of loop iterations to the given value. This value must be
1019+
// valid in the condition block (i.e., defined in the preheader) and is
1020+
// interpreted as an unsigned integer.
1021+
void setCanonicalLoopTripCount(CanonicalLoopInfo *CLI, Value *TripCount) {
1022+
Instruction *CmpI = &CLI->getCond()->front();
1023+
assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
1024+
CmpI->setOperand(1, TripCount);
1025+
CLI->assertOK();
1026+
}
1027+
1028+
// Converts the canonical loop described by `CLI` into a statically scheduled
// workshare loop. The steps are:
//   1. allocate slots for the per-thread bounds at `AllocaIP`;
//   2. in the preheader, store the full iteration space [0, trip-count - 1]
//      into those slots and call the runtime "init" function;
//   3. replace the loop trip count with the size of the range the runtime
//      wrote back, and offset the induction variable by the new lower bound;
//   4. in the exit block, call the runtime "fini" function and, if
//      `NeedsBarrier` is set, emit a barrier right after it.
// Returns `CLI` (updated in place), or nullptr if `Loc` cannot be used.
CanonicalLoopInfo *OpenMPIRBuilder::createStaticWorkshareLoop(
    const LocationDescription &Loc, CanonicalLoopInfo *CLI,
    InsertPointTy AllocaIP, bool NeedsBarrier, Value *Chunk) {
  // Set up the source location value for OpenMP runtime.
  if (!updateToLocation(Loc))
    return nullptr;

  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Without an explicit chunk size, schedule in units of one iteration.
  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  // TODO: extract scheduling type and map it to OMP constant. This is currently
  // happening in kmp.h and its ilk and needs to be moved to OpenMP.td first.
  constexpr int StaticSchedType = 34;
  Constant *SchedulingType = ConstantInt::get(I32Type, StaticSchedType);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced. The bounds are reloaded with an explicit element type
  // so that the loads do not depend on the pointee type of the allocas.
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Chunk});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  setCanonicalLoopTripCount(CLI, TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  // TODO: this can eventually move to CanonicalLoopInfo or to a new
  // CanonicalLoopInfoUpdater interface.
  Builder.SetInsertPoint(CLI->getBody(), CLI->getBody()->getFirstInsertionPt());
  Value *UpdatedIV = Builder.CreateAdd(IV, LowerBound);
  IV->replaceUsesWithIf(UpdatedIV, [&](Use &U) {
    auto *Instr = dyn_cast<Instruction>(U.getUser());
    // Keep the original IV in the loop-control instructions and in the very
    // addition that defines the shifted IV (it would otherwise use itself).
    return !Instr ||
           (Instr->getParent() != CLI->getCond() &&
            Instr->getParent() != CLI->getLatch() && Instr != UpdatedIV);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested. The barrier is emitted at the insertion
  // point carried by the location it is given, so hand it the current position
  // (right after the "fini" call in the exit block) instead of the caller's
  // `Loc`, whose insertion point is unrelated to the end of the loop. Only the
  // debug location is taken from `Loc`.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  CLI->assertOK();
  return CLI;
}
1113+
10021114
void CanonicalLoopInfo::eraseFromParent() {
10031115
assert(IsValid && "can only erase previously valid loop cfg");
10041116
IsValid = false;

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,92 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
10711071
EXPECT_FALSE(verifyModule(*M, &errs()));
10721072
}
10731073

1074+
// Builds a canonical loop over [10, 52) with step 2, turns it into a
// statically scheduled workshare loop and inspects the generated IR: the
// bound allocas, the stores feeding the runtime "init" call, the shifted
// induction variable and the recomputed trip count.
TEST_F(OpenMPIRBuilderTest, StaticWorkShareLoop) {
  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  OpenMPIRBuilder OMPBuilder(*M);
  OMPBuilder.initialize();
  IRBuilder<> Builder(BB);
  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});

  Type *LCTy = Type::getInt32Ty(Ctx);
  Value *StartVal = ConstantInt::get(LCTy, 10);
  Value *StopVal = ConstantInt::get(LCTy, 52);
  Value *StepVal = ConstantInt::get(LCTy, 2);
  auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {};

  CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
      Loc, LoopBodyGen, StartVal, StopVal, StepVal,
      /*IsSigned=*/false, /*InclusiveStop=*/false);

  Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
  InsertPointTy AllocaIP = Builder.saveIP();

  CLI = OMPBuilder.createStaticWorkshareLoop(Loc, CLI, AllocaIP,
                                             /*NeedsBarrier=*/true);
  // The builder may return null; everything below dereferences CLI.
  ASSERT_NE(CLI, nullptr);

  // The four bound slots must be the first instructions at the alloca
  // insertion point, in the order they are created.
  auto AllocaIter = BB->begin();
  ASSERT_GE(std::distance(BB->begin(), BB->end()), 4);
  AllocaInst *PLastIter = dyn_cast<AllocaInst>(&*(AllocaIter++));
  AllocaInst *PLowerBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
  AllocaInst *PUpperBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
  AllocaInst *PStride = dyn_cast<AllocaInst>(&*(AllocaIter++));
  EXPECT_NE(PLastIter, nullptr);
  EXPECT_NE(PLowerBound, nullptr);
  EXPECT_NE(PUpperBound, nullptr);
  EXPECT_NE(PStride, nullptr);

  // The preheader starts with the stores of the original bounds and stride.
  auto PreheaderIter = CLI->getPreheader()->begin();
  ASSERT_GE(
      std::distance(CLI->getPreheader()->begin(), CLI->getPreheader()->end()),
      7);
  StoreInst *LowerBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
  StoreInst *UpperBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
  StoreInst *StrideStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
  ASSERT_NE(LowerBoundStore, nullptr);
  ASSERT_NE(UpperBoundStore, nullptr);
  ASSERT_NE(StrideStore, nullptr);

  // 21 iterations (10, 12, ..., 50) give the inclusive range [0, 20].
  auto *OrigLowerBound =
      dyn_cast<ConstantInt>(LowerBoundStore->getValueOperand());
  auto *OrigUpperBound =
      dyn_cast<ConstantInt>(UpperBoundStore->getValueOperand());
  auto *OrigStride = dyn_cast<ConstantInt>(StrideStore->getValueOperand());
  ASSERT_NE(OrigLowerBound, nullptr);
  ASSERT_NE(OrigUpperBound, nullptr);
  ASSERT_NE(OrigStride, nullptr);
  EXPECT_EQ(OrigLowerBound->getValue(), 0);
  EXPECT_EQ(OrigUpperBound->getValue(), 20);
  EXPECT_EQ(OrigStride->getValue(), 1);

  // Check that the loop IV is updated to account for the lower bound returned
  // by the OpenMP runtime call.
  BinaryOperator *Add = dyn_cast<BinaryOperator>(&CLI->getBody()->front());
  // Fail cleanly instead of dereferencing null if the body does not start
  // with the expected addition.
  ASSERT_NE(Add, nullptr);
  EXPECT_EQ(Add->getOperand(0), CLI->getIndVar());
  auto *LoadedLowerBound = dyn_cast<LoadInst>(Add->getOperand(1));
  ASSERT_NE(LoadedLowerBound, nullptr);
  EXPECT_EQ(LoadedLowerBound->getPointerOperand(), PLowerBound);

  // Check that the trip count is updated to account for the lower and upper
  // bounds returned by the OpenMP runtime call.
  auto *AddOne = dyn_cast<Instruction>(CLI->getTripCount());
  ASSERT_NE(AddOne, nullptr);
  ASSERT_TRUE(AddOne->isBinaryOp());
  auto *One = dyn_cast<ConstantInt>(AddOne->getOperand(1));
  ASSERT_NE(One, nullptr);
  EXPECT_EQ(One->getValue(), 1);
  auto *Difference = dyn_cast<Instruction>(AddOne->getOperand(0));
  ASSERT_NE(Difference, nullptr);
  ASSERT_TRUE(Difference->isBinaryOp());
  EXPECT_EQ(Difference->getOperand(1), LoadedLowerBound);
  auto *LoadedUpperBound = dyn_cast<LoadInst>(Difference->getOperand(0));
  ASSERT_NE(LoadedUpperBound, nullptr);
  EXPECT_EQ(LoadedUpperBound->getPointerOperand(), PUpperBound);

  // The original loop iterator should only be used in the condition, in the
  // increment and in the statement that adds the lower bound to it.
  Value *IV = CLI->getIndVar();
  EXPECT_EQ(std::distance(IV->use_begin(), IV->use_end()), 3);
}
1159+
10741160
TEST_F(OpenMPIRBuilderTest, MasterDirective) {
10751161
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
10761162
OpenMPIRBuilder OMPBuilder(*M);

0 commit comments

Comments
 (0)