
Commit bb4484d

[OpenMPIRBuilder] Add support for target workshare loops (#73360)
The workshare loop for a target region uses the new OpenMP device runtime. The code generation scheme for the new device runtime is presented below.

Input code:

```
workshare-loop {
  loop-body
}
```

Output code, a helper function which represents the loop body:

```
function-loop-body(counter, loop-body-args) {
  loop-body
}
```

The workshare-loop is then replaced by the proper device runtime call:

```
call __kmpc_new_worksharing_rtl(function-loop-body, loop-body-args, loop-tripcount, ...)
```

This PR uses the new device runtime functions which were added in PR #73225.
1 parent de21308 commit bb4484d
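For orientation, a source-level analogue of the construct that takes this code path is sketched below. It is illustrative only and not part of the commit: when such a loop is lowered for an OpenMP target device through the OpenMPIRBuilder, the annotated loop plays the role of the workshare-loop in the scheme above and its body becomes the outlined loop-body function.

```c++
// Hypothetical user code (not from this patch): a target-region workshare
// loop. The loop below is the "workshare-loop"; its body is the "loop-body".
#include <cstdio>

int main() {
  double a[256];
#pragma omp target teams distribute parallel for map(from : a)
  for (int i = 0; i < 256; ++i) // workshare-loop
    a[i] = 2.0 * i;             // loop-body
  std::printf("%f\n", a[128]);
  return 0;
}
```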

4 files changed, +342 -4 lines changed

llvm/include/llvm/Frontend/OpenMP/OMPConstants.h

Lines changed: 10 additions & 0 deletions
```diff
@@ -277,6 +277,16 @@ enum class RTLDependenceKindTy {
   DepOmpAllMem = 0x80,
 };
 
+/// A type of worksharing loop construct
+enum class WorksharingLoopType {
+  // Worksharing `for`-loop
+  ForStaticLoop,
+  // Worksharing `distribute`-loop
+  DistributeStaticLoop,
+  // Worksharing `distribute parallel for`-loop
+  DistributeForStaticLoop
+};
+
 } // end namespace omp
 
 } // end namespace llvm
```
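The three enumerators line up with the loop-worksharing directives a frontend may lower. A rough mapping is sketched below; the `DirectiveKind` enum and the `selectLoopType` helper are illustrative assumptions, not code from this patch, and only the `WorksharingLoopType` values come from the diff above.

```c++
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/ErrorHandling.h"

// Hypothetical directive classification on the frontend side.
enum class DirectiveKind { For, Distribute, DistributeParallelFor };

// Illustrative mapping from a lowered directive to the new enum.
static llvm::omp::WorksharingLoopType selectLoopType(DirectiveKind DK) {
  switch (DK) {
  case DirectiveKind::For: // e.g. `#pragma omp for`
    return llvm::omp::WorksharingLoopType::ForStaticLoop;
  case DirectiveKind::Distribute: // e.g. `#pragma omp distribute`
    return llvm::omp::WorksharingLoopType::DistributeStaticLoop;
  case DirectiveKind::DistributeParallelFor: // `distribute parallel for`
    return llvm::omp::WorksharingLoopType::DistributeForStaticLoop;
  }
  llvm_unreachable("unknown directive kind");
}
```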

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 27 additions & 1 deletion
```diff
@@ -900,6 +900,28 @@ class OpenMPIRBuilder {
       omp::OpenMPOffloadMappingFlags MemberOfFlag);
 
 private:
+  /// Modifies the canonical loop to be a statically-scheduled workshare loop
+  /// which is executed on the device
+  ///
+  /// This takes a \p CLI representing a canonical loop, such as the one
+  /// created by \see createCanonicalLoop and emits additional instructions to
+  /// turn it into a workshare loop. In particular, it adds a call to an OpenMP
+  /// runtime function in the preheader to call the OpenMP device rtl function
+  /// which handles worksharing of loop body iterations.
+  ///
+  /// \param DL       Debug location for instructions added for the
+  ///                 workshare-loop construct itself.
+  /// \param CLI      A descriptor of the canonical loop to workshare.
+  /// \param AllocaIP An insertion point for Alloca instructions usable in the
+  ///                 preheader of the loop.
+  /// \param LoopType Information about the type of loop worksharing.
+  ///                 It corresponds to the type of the loop workshare OpenMP pragma.
+  ///
+  /// \returns Point where to insert code after the workshare construct.
+  InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
+                                         InsertPointTy AllocaIP,
+                                         omp::WorksharingLoopType LoopType);
+
   /// Modifies the canonical loop to be a statically-scheduled workshare loop.
   ///
   /// This takes a \p LoopInfo representing a canonical loop, such as the one
@@ -1012,6 +1034,8 @@ class OpenMPIRBuilder {
   ///        present in the schedule clause.
   /// \param HasOrderedClause Whether the (parameterless) ordered clause is
   ///        present.
+  /// \param LoopType Information about the type of loop worksharing.
+  ///        It corresponds to the type of the loop workshare OpenMP pragma.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointTy applyWorkshareLoop(
@@ -1020,7 +1044,9 @@ class OpenMPIRBuilder {
       llvm::omp::ScheduleKind SchedKind = llvm::omp::OMP_SCHEDULE_Default,
       Value *ChunkSize = nullptr, bool HasSimdModifier = false,
       bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
-      bool HasOrderedClause = false);
+      bool HasOrderedClause = false,
+      omp::WorksharingLoopType LoopType =
+          omp::WorksharingLoopType::ForStaticLoop);
 
   /// Tile a loop nest.
   ///
```
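A minimal sketch of how a caller might use the extended signature is given below, assuming the caller already holds an `OpenMPIRBuilder` and a `CanonicalLoopInfo` obtained from `createCanonicalLoop`; only the trailing `LoopType` argument is new in this patch, and the chosen argument values are illustrative.

```c++
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"

using namespace llvm;

// Sketch: request a `distribute parallel for` style workshare loop. On a
// target device this dispatches to the new applyWorkshareLoopTarget path;
// on the host it follows the existing static-scheduling path.
static OpenMPIRBuilder::InsertPointTy
emitTargetWorkshareLoop(OpenMPIRBuilder &OMPBuilder, DebugLoc DL,
                        CanonicalLoopInfo *CLI,
                        OpenMPIRBuilder::InsertPointTy AllocaIP) {
  return OMPBuilder.applyWorkshareLoop(
      DL, CLI, AllocaIP, /*NeedsBarrier=*/true,
      llvm::omp::OMP_SCHEDULE_Default,
      /*ChunkSize=*/nullptr, /*HasSimdModifier=*/false,
      /*HasMonotonicModifier=*/false, /*HasNonmonotonicModifier=*/false,
      /*HasOrderedClause=*/false,
      omp::WorksharingLoopType::DistributeForStaticLoop);
}
```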

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 234 additions & 3 deletions
```diff
@@ -2674,11 +2674,242 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
 }
 
+// Returns an LLVM function to call for executing an OpenMP static worksharing
+// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
+// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+                            WorksharingLoopType LoopType) {
+  unsigned Bitwidth = Ty->getIntegerBitWidth();
+  Module &M = OMPBuilder->M;
+  switch (LoopType) {
+  case WorksharingLoopType::ForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+    break;
+  case WorksharingLoopType::DistributeStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+    break;
+  case WorksharingLoopType::DistributeForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+    break;
+  }
+  if (Bitwidth != 32 && Bitwidth != 64) {
+    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
+  }
+  llvm_unreachable("Unknown type of OpenMP worksharing loop");
+}
+
+// Inserts a call to the proper OpenMP Device RTL function which handles
+// loop worksharing.
+static void createTargetLoopWorkshareCall(
+    OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
+    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+  Type *TripCountTy = TripCount->getType();
+  Module &M = OMPBuilder->M;
+  IRBuilder<> &Builder = OMPBuilder->Builder;
+  FunctionCallee RTLFn =
+      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+  SmallVector<Value *, 8> RealArgs;
+  RealArgs.push_back(Ident);
+  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+  RealArgs.push_back(LoopBodyArg);
+  RealArgs.push_back(TripCount);
+  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
+    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
+    Builder.CreateCall(RTLFn, RealArgs);
+    return;
+  }
+  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+  RealArgs.push_back(
+      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
+  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
+    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
+  }
+
+  Builder.CreateCall(RTLFn, RealArgs);
+}
+
+static void
+workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
+                            CanonicalLoopInfo *CLI, Value *Ident,
+                            Function &OutlinedFn, Type *ParallelTaskPtr,
+                            const SmallVector<Instruction *, 4> &ToBeDeleted,
+                            WorksharingLoopType LoopType) {
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  // After loop body outlining, the loop body contains only the setup of the
+  // loop body argument structure and the call to the outlined loop body
+  // function. Firstly, we need to move the setup of the loop body args into
+  // the loop preheader.
+  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
+                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
+
+  // The next step is to remove the whole loop. We do not need it anymore.
+  // That's why we make an unconditional branch from the loop preheader to the
+  // loop exit block.
+  Builder.restoreIP({Preheader, Preheader->end()});
+  Preheader->getTerminator()->eraseFromParent();
+  Builder.CreateBr(CLI->getExit());
+
+  // Delete dead loop blocks.
+  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
+  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
+  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
+  CleanUpInfo.EntryBB = CLI->getHeader();
+  CleanUpInfo.ExitBB = CLI->getExit();
+  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
+  DeleteDeadBlocks(BlocksToBeRemoved);
+
+  // Find the instruction which corresponds to the loop body argument structure
+  // and remove the call to the loop body function instruction.
+  Value *LoopBodyArg;
+  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
+  assert(OutlinedFnUser &&
+         "Expected unique undroppable user of outlined function");
+  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
+  assert(OutlinedFnCallInstruction && "Expected outlined function call");
+  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
+         "Expected outlined function call to be located in loop preheader");
+  // Check in case no argument structure has been passed.
+  if (OutlinedFnCallInstruction->arg_size() > 1)
+    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
+  else
+    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+  OutlinedFnCallInstruction->eraseFromParent();
+
+  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
+                                LoopBodyArg, ParallelTaskPtr, TripCount,
+                                OutlinedFn);
+
+  for (auto &ToBeDeletedItem : ToBeDeleted)
+    ToBeDeletedItem->eraseFromParent();
+  CLI->invalidate();
+}
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
+                                          InsertPointTy AllocaIP,
+                                          WorksharingLoopType LoopType) {
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  OutlineInfo OI;
+  OI.OuterAllocaBB = CLI->getPreheader();
+  Function *OuterFn = CLI->getPreheader()->getParent();
+
+  // Instructions which need to be deleted at the end of code generation
+  SmallVector<Instruction *, 4> ToBeDeleted;
+
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+
+  // Mark the loop body as the region which needs to be extracted
+  OI.EntryBB = CLI->getBody();
+  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+                                               "omp.prelatch", true);
+
+  // Prepare loop body for extraction
+  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
+
+  // Insert a new loop counter variable which will be used only in the loop
+  // body.
+  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
+  Instruction *NewLoopCntLoad =
+      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
+  // New loop counter instructions are redundant in the loop preheader when
+  // code generation for the workshare loop is finished. That's why we mark
+  // them as ready for deletion.
+  ToBeDeleted.push_back(NewLoopCntLoad);
+  ToBeDeleted.push_back(NewLoopCnt);
+
+  // Analyse the loop body region. Find all input variables which are used
+  // inside the loop body region.
+  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
+  SmallVector<BasicBlock *, 32> Blocks;
+  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
+                                        ParallelRegionBlockSet.end());
+
+  CodeExtractorAnalysisCache CEAC(*OuterFn);
+  CodeExtractor Extractor(Blocks,
+                          /* DominatorTree */ nullptr,
+                          /* AggregateArgs */ true,
+                          /* BlockFrequencyInfo */ nullptr,
+                          /* BranchProbabilityInfo */ nullptr,
+                          /* AssumptionCache */ nullptr,
+                          /* AllowVarArgs */ true,
+                          /* AllowAlloca */ true,
+                          /* AllocationBlock */ CLI->getPreheader(),
+                          /* Suffix */ ".omp_wsloop",
+                          /* AggrArgsIn0AddrSpace */ true);
+
+  BasicBlock *CommonExit = nullptr;
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+
+  // Find allocas outside the loop body region which are used inside the loop
+  // body
+  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+
+  // We need to model the loop body region as the function f(cnt, loop_arg).
+  // That's why we replace the loop induction variable by the new counter
+  // which will be one of the loop body function arguments
+  for (auto Use = CLI->getIndVar()->user_begin();
+       Use != CLI->getIndVar()->user_end(); ++Use) {
+    if (Instruction *Inst = dyn_cast<Instruction>(*Use)) {
+      if (ParallelRegionBlockSet.count(Inst->getParent())) {
+        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
+      }
+    }
+  }
+  // Make sure that the loop counter variable is not merged into the loop body
+  // function argument structure and that it is passed as a separate variable
+  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
+
+  // The PostOutline CB is invoked when the loop body function is outlined and
+  // the loop body is replaced by a call to the outlined function. We need to
+  // add a call to the OpenMP device rtl inside the loop preheader. The OpenMP
+  // device rtl function will handle the loop control logic.
+  //
+  OI.PostOutlineCB = [=, ToBeDeletedVec =
+                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
+                                ToBeDeletedVec, LoopType);
+  };
+  addOutlineInfo(std::move(OI));
+  return CLI->getAfterIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
-    llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
-    bool HasNonmonotonicModifier, bool HasOrderedClause) {
+    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
+    bool HasSimdModifier, bool HasMonotonicModifier,
+    bool HasNonmonotonicModifier, bool HasOrderedClause,
+    WorksharingLoopType LoopType) {
+  if (Config.isTargetDevice())
+    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
       HasNonmonotonicModifier, HasOrderedClause);
```
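The target path above is reached only when the builder's configuration reports device compilation. A minimal sketch of that setup follows; `OpenMPIRBuilderConfig`, `setIsTargetDevice`, `setConfig`, and `initialize` are assumed here from the surrounding OpenMPIRBuilder API, and the module handling is illustrative rather than code from this commit.

```c++
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Sketch (assumed API usage, not part of this commit): configure the builder
// so that Config.isTargetDevice() returns true and applyWorkshareLoop takes
// the new applyWorkshareLoopTarget path.
static void setUpDeviceBuilder(Module &M) {
  OpenMPIRBuilder OMPBuilder(M);
  OpenMPIRBuilderConfig Config;
  Config.setIsTargetDevice(true); // device compilation -> target workshare path
  OMPBuilder.setConfig(Config);
  OMPBuilder.initialize();
  // ... createCanonicalLoop(...) and applyWorkshareLoop(...) would follow,
  // as in the header-level sketch above ...
}
```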
