Skip to content

Commit 7029b6f

Browse files
[OpenMPIRBuilder] Add support for target workshare loops
The workshare loop for target region uses the new OpenMP device runtime. The code generation scheme for the new device runtime is presented below: Input code: workshare-loop { loop-body } Output code: helper function: function-loop-body(counter, loop-body-args) { loop-body } workshare-loop is replaced by the proper device runtime call: call __kmpc_new_worksharing_rtl(function-loop-body, loop-body-args, loop-tripcount, ...)
1 parent e78a45d commit 7029b6f

File tree

3 files changed

+348
-2
lines changed

3 files changed

+348
-2
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,16 @@ class OffloadEntriesInfoManager {
439439
/// Each OpenMP directive has a corresponding public generator method.
440440
class OpenMPIRBuilder {
441441
public:
442+
/// A type of worksharing loop construct
443+
enum class WorksharingLoopType {
444+
// Worksharing `for`-loop
445+
ForStaticLoop,
446+
// Worksharing `distribute`-loop
447+
DistributeStaticLoop,
448+
// Worksharing `distribute parallel for`-loop
449+
DistributeForStaticLoop
450+
};
451+
442452
/// Create a new OpenMPIRBuilder operating on the given module \p M. This will
443453
/// not have an effect on \p M (see initialize)
444454
OpenMPIRBuilder(Module &M)
@@ -900,6 +910,28 @@ class OpenMPIRBuilder {
900910
omp::OpenMPOffloadMappingFlags MemberOfFlag);
901911

902912
private:
913+
/// Modifies the canonical loop to be a statically-scheduled workshare loop
914+
/// which is executed on the device
915+
///
916+
/// This takes a \p LoopInfo representing a canonical loop, such as the one
917+
/// created by \p createCanonicalLoop and emits additional instructions to
918+
/// turn it into a workshare loop. In particular, it calls to an OpenMP
919+
/// runtime function in the preheader to call OpenMP device rtl function
920+
/// which handles worksharing of loop body iterations.
921+
///
922+
/// \param DL Debug location for instructions added for the
923+
/// workshare-loop construct itself.
924+
/// \param CLI A descriptor of the canonical loop to workshare.
925+
/// \param AllocaIP An insertion point for Alloca instructions usable in the
926+
/// preheader of the loop.
927+
/// \param LoopType Information about type of loop worksharing.
928+
/// It corresponds to type of loop workshare OpenMP pragma.
929+
///
930+
/// \returns Point where to insert code after the workshare construct.
931+
InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
932+
InsertPointTy AllocaIP,
933+
WorksharingLoopType LoopType);
934+
903935
/// Modifies the canonical loop to be a statically-scheduled workshare loop.
904936
///
905937
/// This takes a \p LoopInfo representing a canonical loop, such as the one
@@ -1012,6 +1044,8 @@ class OpenMPIRBuilder {
10121044
/// present in the schedule clause.
10131045
/// \param HasOrderedClause Whether the (parameterless) ordered clause is
10141046
/// present.
1047+
/// \param LoopType Information about type of loop worksharing.
1048+
/// It corresponds to type of loop workshare OpenMP pragma.
10151049
///
10161050
/// \returns Point where to insert code after the workshare construct.
10171051
InsertPointTy applyWorkshareLoop(
@@ -1020,7 +1054,8 @@ class OpenMPIRBuilder {
10201054
llvm::omp::ScheduleKind SchedKind = llvm::omp::OMP_SCHEDULE_Default,
10211055
Value *ChunkSize = nullptr, bool HasSimdModifier = false,
10221056
bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
1023-
bool HasOrderedClause = false);
1057+
bool HasOrderedClause = false,
1058+
WorksharingLoopType LoopType = WorksharingLoopType::ForStaticLoop);
10241059

10251060
/// Tile a loop nest.
10261061
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 245 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2674,11 +2674,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
26742674
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
26752675
}
26762676

2677+
// Returns an LLVM function to call for executing an OpenMP static worksharing
2678+
// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
2679+
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
2680+
static FunctionCallee
2681+
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
2682+
OpenMPIRBuilder::WorksharingLoopType LoopType) {
2683+
unsigned Bitwidth = Ty->getIntegerBitWidth();
2684+
Module &M = OMPBuilder->M;
2685+
switch (LoopType) {
2686+
case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
2687+
if (Bitwidth == 32)
2688+
return OMPBuilder->getOrCreateRuntimeFunction(
2689+
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
2690+
if (Bitwidth == 64)
2691+
return OMPBuilder->getOrCreateRuntimeFunction(
2692+
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
2693+
break;
2694+
case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
2695+
if (Bitwidth == 32)
2696+
return OMPBuilder->getOrCreateRuntimeFunction(
2697+
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
2698+
if (Bitwidth == 64)
2699+
return OMPBuilder->getOrCreateRuntimeFunction(
2700+
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
2701+
break;
2702+
case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
2703+
if (Bitwidth == 32)
2704+
return OMPBuilder->getOrCreateRuntimeFunction(
2705+
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
2706+
if (Bitwidth == 64)
2707+
return OMPBuilder->getOrCreateRuntimeFunction(
2708+
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
2709+
break;
2710+
}
2711+
if (Bitwidth != 32 && Bitwidth != 64)
2712+
llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2713+
return FunctionCallee();
2714+
}
2715+
2716+
// Inserts a call to the proper OpenMP Device RTL function which handles
// loop worksharing.
//
// The runtime entry point is selected from \p LoopType and the bitwidth of
// \p TripCount (see getKmpcForStaticLoopForType). The argument list always
// starts with (ident, loop-body function, loop-body arg struct, trip count);
// the distribute/for variants additionally receive num-threads and chunk
// sizes as documented inline below.
//
// NOTE(review): for the non-DistributeStaticLoop paths the builder insertion
// point is repositioned to just before the terminator of \p InsertBlock
// before emitting the omp_get_num_threads and worksharing calls — callers
// presumably arrange for \p InsertBlock to be the loop preheader; confirm
// against workshareLoopTargetCallback.
static void createTargetLoopWorkshareCall(
    OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Ident);
  /*loop body func*/
  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
  /*loop body args*/
  RealArgs.push_back(LoopBodyArg);
  /*num of iters*/
  RealArgs.push_back(TripCount);
  if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
    // The distribute-only runtime call takes no num-threads/thread-chunk
    // arguments; emit a zero block chunk (i.e. default chunking) and return.
    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
                                           ? Builder.getInt32(0)
                                           : Builder.getInt64(0));
    Builder.CreateCall(RTLFn, RealArgs);
    return;
  }
  // The `for` and `distribute parallel for` variants need the number of
  // threads; query it via the device runtime at the insertion block.
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

  // omp_get_num_threads returns i32; widen/narrow it to the trip-count type.
  /*num of threads*/ RealArgs.push_back(
      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
  if (LoopType ==
      OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
                                           ? Builder.getInt32(0)
                                           : Builder.getInt64(0));
  }
  // Thread chunk of 1: each thread takes single iterations in turn.
  /*thread chunk */ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
                                           ? Builder.getInt32(1)
                                           : Builder.getInt64(1));

  Builder.CreateCall(RTLFn, RealArgs);
}
2761+
2762+
static void
2763+
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
2764+
CanonicalLoopInfo *CLI, Value *Ident,
2765+
Function &OutlinedFn, Type *ParallelTaskPtr,
2766+
const SmallVector<Instruction *, 4> &ToBeDeleted,
2767+
OpenMPIRBuilder::WorksharingLoopType LoopType) {
2768+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
2769+
BasicBlock *Preheader = CLI->getPreheader();
2770+
Value *TripCount = CLI->getTripCount();
2771+
2772+
// After loop body outling, the loop body contains only set up
2773+
// of loop body argument structure and the call to the outlined
2774+
// loop body function. Firstly, we need to move setup of loop body args
2775+
// into loop preheader.
2776+
Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
2777+
CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
2778+
2779+
// The next step is to remove the whole loop. We do not it need anymore.
2780+
// That's why make an unconditional branch from loop preheader to loop
2781+
// exit block
2782+
Builder.restoreIP({Preheader, Preheader->end()});
2783+
Preheader->getTerminator()->eraseFromParent();
2784+
Builder.CreateBr(CLI->getExit());
2785+
2786+
// Delete dead loop blocks
2787+
OpenMPIRBuilder::OutlineInfo CleanUpInfo;
2788+
SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
2789+
SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
2790+
CleanUpInfo.EntryBB = CLI->getHeader();
2791+
CleanUpInfo.ExitBB = CLI->getExit();
2792+
CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
2793+
DeleteDeadBlocks(BlocksToBeRemoved);
2794+
2795+
// Find the instruction which corresponds to loop body argument structure
2796+
// and remove the call to loop body function instruction.
2797+
Value *LoopBodyArg;
2798+
for (auto instIt = Preheader->begin(); instIt != Preheader->end(); ++instIt) {
2799+
if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
2800+
if (CallInstruction->getCalledFunction() == &OutlinedFn) {
2801+
// Check in case no argument structure has been passed.
2802+
if (CallInstruction->arg_size() > 1)
2803+
LoopBodyArg = CallInstruction->getArgOperand(1);
2804+
else
2805+
LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
2806+
CallInstruction->eraseFromParent();
2807+
break;
2808+
}
2809+
}
2810+
}
2811+
2812+
createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
2813+
LoopBodyArg, ParallelTaskPtr, TripCount,
2814+
OutlinedFn);
2815+
2816+
for (auto &ToBeDeletedItem : ToBeDeleted)
2817+
ToBeDeletedItem->eraseFromParent();
2818+
CLI->invalidate();
2819+
}
2820+
2821+
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OpenMPIRBuilder::WorksharingLoopType LoopType) {
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  OutlineInfo OI;
  Function *OuterFn = CLI->getPreheader()->getParent();

  // Instructions which need to be deleted at the end of code generation.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // FIX: OuterAllocaBB was previously assigned twice (first to the loop
  // preheader, then immediately overwritten); only the caller-provided
  // alloca block is meaningful.
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Mark the body loop as the region which needs to be extracted.
  OI.EntryBB = CLI->getBody();
  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
                                               "omp.prelatch", true);

  // Prepare the loop body for extraction.
  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

  // Insert a new loop counter variable which will be used only in the loop
  // body.
  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
  Instruction *NewLoopCntLoad =
      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
  // The new loop counter instructions are redundant in the loop preheader
  // when code generation for the workshare loop is finished. That's why
  // they are marked as ready for deletion.
  ToBeDeleted.push_back(NewLoopCntLoad);
  ToBeDeleted.push_back(NewLoopCnt);

  // Analyse the loop body region. Find all input variables which are used
  // inside the loop body region.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
  // (The previous copy of ParallelRegionBlockSet into a local `BlocksT`
  // vector was never used and has been removed.)

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks,
                          /* DominatorTree */ nullptr,
                          /* AggregateArgs */ true,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ CLI->getPreheader(),
                          /* Suffix */ ".omp_wsloop",
                          /* AggrArgsIn0AddrSpace */ true);

  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;

  // Find allocas outside the loop body region which are used inside the
  // loop body.
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That's why we replace the loop induction variable by the new counter
  // which will be one of the loop body function arguments.
  std::vector<User *> Users(CLI->getIndVar()->user_begin(),
                            CLI->getIndVar()->user_end());
  for (User *U : Users) {
    if (Instruction *Inst = dyn_cast<Instruction>(U)) {
      if (ParallelRegionBlockSet.count(Inst->getParent()))
        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
    }
  }
  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
  for (Value *Input : Inputs) {
    // Make sure that the loop counter variable is not merged into the loop
    // body function argument structure and is passed as a separate variable.
    if (Input == NewLoopCntLoad)
      OI.ExcludeArgsFromAggregate.push_back(Input);
  }

  // The PostOutline CB is invoked when the loop body function is outlined
  // and the loop body is replaced by a call to the outlined function. We
  // need to add a call to the OpenMP device rtl inside the loop preheader;
  // the OpenMP device rtl function will handle the loop control logic.
  OI.PostOutlineCB = [=, ToBeDeletedVec =
                             std::move(ToBeDeleted)](Function &OutlinedFn) {
    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
                                ToBeDeletedVec, LoopType);
  };
  addOutlineInfo(std::move(OI));
  return CLI->getAfterIP();
}
2917+
26772918
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
26782919
DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
26792920
bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
26802921
llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
2681-
bool HasNonmonotonicModifier, bool HasOrderedClause) {
2922+
bool HasNonmonotonicModifier, bool HasOrderedClause,
2923+
OpenMPIRBuilder::WorksharingLoopType LoopType) {
2924+
if (Config.isTargetDevice())
2925+
return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
26822926
OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
26832927
SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
26842928
HasNonmonotonicModifier, HasOrderedClause);

0 commit comments

Comments
 (0)