[OpenMPIRBuilder] Add support for target workshare loops #73360

Merged 4 commits on Dec 6, 2023

10 changes: 10 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -277,6 +277,16 @@ enum class RTLDependenceKindTy {
DepOmpAllMem = 0x80,
};

/// A type of worksharing loop construct
enum class WorksharingLoopType {
// Worksharing `for`-loop
ForStaticLoop,
// Worksharing `distribute`-loop
DistributeStaticLoop,
// Worksharing `distribute parallel for`-loop
DistributeForStaticLoop
};

} // end namespace omp

} // end namespace llvm
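The three enumerators map one-to-one onto the loop-worksharing constructs a frontend may lower. A minimal, hypothetical sketch of that mapping; the `LoweredDirective` enum and `selectLoopType` helper below are illustrative and not part of this patch:

```cpp
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/ErrorHandling.h"

// Illustrative frontend-side enum; not part of the patch.
enum class LoweredDirective { For, Distribute, DistributeParallelFor };

static llvm::omp::WorksharingLoopType selectLoopType(LoweredDirective Dir) {
  using llvm::omp::WorksharingLoopType;
  switch (Dir) {
  case LoweredDirective::For:                   // #pragma omp for
    return WorksharingLoopType::ForStaticLoop;
  case LoweredDirective::Distribute:            // #pragma omp distribute
    return WorksharingLoopType::DistributeStaticLoop;
  case LoweredDirective::DistributeParallelFor: // #pragma omp distribute parallel for
    return WorksharingLoopType::DistributeForStaticLoop;
  }
  llvm_unreachable("unhandled directive");
}
```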
28 changes: 27 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -900,6 +900,28 @@ class OpenMPIRBuilder {
omp::OpenMPOffloadMappingFlags MemberOfFlag);

private:
/// Modifies the canonical loop to be a statically-scheduled workshare loop
/// which is executed on the device
///
/// This takes a \p CLI representing a canonical loop, such as the one
/// created by \see createCanonicalLoop and emits additional instructions to
/// turn it into a workshare loop. In particular, it emits a call to an OpenMP
/// device RTL function in the preheader which handles the worksharing of the
/// loop body iterations.
///
/// \param DL Debug location for instructions added for the
/// workshare-loop construct itself.
/// \param CLI A descriptor of the canonical loop to workshare.
/// \param AllocaIP An insertion point for Alloca instructions usable in the
/// preheader of the loop.
/// \param LoopType Information about the type of loop worksharing.
///        It corresponds to the type of the loop-worksharing OpenMP pragma.
///
/// \returns Point where to insert code after the workshare construct.
InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
InsertPointTy AllocaIP,
omp::WorksharingLoopType LoopType);

/// Modifies the canonical loop to be a statically-scheduled workshare loop.
///
/// This takes a \p LoopInfo representing a canonical loop, such as the one
@@ -1012,6 +1034,8 @@ class OpenMPIRBuilder {
/// present in the schedule clause.
/// \param HasOrderedClause Whether the (parameterless) ordered clause is
/// present.
/// \param LoopType Information about the type of loop worksharing.
///        It corresponds to the type of the loop-worksharing OpenMP pragma.
///
/// \returns Point where to insert code after the workshare construct.
InsertPointTy applyWorkshareLoop(
@@ -1020,7 +1044,9 @@
llvm::omp::ScheduleKind SchedKind = llvm::omp::OMP_SCHEDULE_Default,
Value *ChunkSize = nullptr, bool HasSimdModifier = false,
bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
bool HasOrderedClause = false);
bool HasOrderedClause = false,
omp::WorksharingLoopType LoopType =
omp::WorksharingLoopType::ForStaticLoop);
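
For callers, the only visible change is the new trailing `LoopType` parameter, which defaults to `ForStaticLoop` so existing call sites keep their behavior. A minimal usage sketch, assuming an already-configured `OpenMPIRBuilder`, a `LocationDescription`, an alloca insertion point, and an unsigned trip count; the `emitTargetWorkshareLoop` wrapper and the empty body callback are illustrative only:

```cpp
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"

using namespace llvm;

// Illustrative wrapper, not part of the patch.
static OpenMPIRBuilder::InsertPointTy
emitTargetWorkshareLoop(OpenMPIRBuilder &OMPBuilder,
                        const OpenMPIRBuilder::LocationDescription &Loc,
                        OpenMPIRBuilder::InsertPointTy AllocaIP,
                        Value *TripCount) {
  // Build a canonical loop; the callback fills in the body for each IV value.
  CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
      Loc,
      [&](OpenMPIRBuilder::InsertPointTy CodeGenIP, Value *IV) {
        // ... emit the loop body here using IV ...
      },
      TripCount, "wsloop");

  // Lower it as a `distribute parallel for` workshare loop. When the module
  // is compiled for a target device (Config.isTargetDevice()), this dispatches
  // to applyWorkshareLoopTarget.
  return OMPBuilder.applyWorkshareLoop(
      Loc.DL, CLI, AllocaIP, /*NeedsBarrier=*/true,
      omp::OMP_SCHEDULE_Default, /*ChunkSize=*/nullptr,
      /*HasSimdModifier=*/false, /*HasMonotonicModifier=*/false,
      /*HasNonmonotonicModifier=*/false, /*HasOrderedClause=*/false,
      omp::WorksharingLoopType::DistributeForStaticLoop);
}
```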

/// Tile a loop nest.
///
237 changes: 234 additions & 3 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2674,11 +2674,242 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}

// Returns an LLVM function to call for executing an OpenMP static worksharing
// for-loop depending on `LoopType`. Only i32 and i64 are supported by the runtime.
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
WorksharingLoopType LoopType) {
unsigned Bitwidth = Ty->getIntegerBitWidth();
Module &M = OMPBuilder->M;
switch (LoopType) {
case WorksharingLoopType::ForStaticLoop:
if (Bitwidth == 32)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
if (Bitwidth == 64)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
break;
case WorksharingLoopType::DistributeStaticLoop:
if (Bitwidth == 32)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
if (Bitwidth == 64)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
break;
case WorksharingLoopType::DistributeForStaticLoop:
if (Bitwidth == 32)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
if (Bitwidth == 64)
return OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
break;
}
if (Bitwidth != 32 && Bitwidth != 64) {
llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
}
llvm_unreachable("Unknown type of OpenMP worksharing loop");
}

// Inserts a call to the proper OpenMP Device RTL function which handles
// loop worksharing.
static void createTargetLoopWorkshareCall(
OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
Type *TripCountTy = TripCount->getType();
Module &M = OMPBuilder->M;
IRBuilder<> &Builder = OMPBuilder->Builder;
FunctionCallee RTLFn =
getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
SmallVector<Value *, 8> RealArgs;
RealArgs.push_back(Ident);
RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
RealArgs.push_back(LoopBodyArg);
RealArgs.push_back(TripCount);
if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
Builder.CreateCall(RTLFn, RealArgs);
return;
}
FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

RealArgs.push_back(
Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
}

Builder.CreateCall(RTLFn, RealArgs);
}
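
Read together with `getKmpcForStaticLoopForType`, the argument vectors assembled here imply the following rough shapes for the device RTL entry points (32-bit variants shown). These prototypes are reconstructed from the call sites in this patch for illustration only; the parameter names are guesses, and the authoritative types are the ones registered for the corresponding `OMPRTL_` entries and defined in the OpenMP device runtime:

```cpp
// Reconstructed from the argument order built above; not copied from the
// device runtime, and parameter names are assumptions.
#include <cstdint>

extern "C" {
// WorksharingLoopType::ForStaticLoop:
//   {Ident, LoopBodyFn, LoopBodyArg, TripCount, NumThreads, 0}
void __kmpc_for_static_loop_4u(void *Ident,
                               void (*LoopBodyFn)(uint32_t, void *),
                               void *LoopBodyArg, uint32_t TripCount,
                               uint32_t NumThreads, uint32_t BlockChunk);

// WorksharingLoopType::DistributeStaticLoop:
//   {Ident, LoopBodyFn, LoopBodyArg, TripCount, 0}
void __kmpc_distribute_static_loop_4u(void *Ident,
                                      void (*LoopBodyFn)(uint32_t, void *),
                                      void *LoopBodyArg, uint32_t TripCount,
                                      uint32_t BlockChunk);

// WorksharingLoopType::DistributeForStaticLoop:
//   {Ident, LoopBodyFn, LoopBodyArg, TripCount, NumThreads, 0, 0}
void __kmpc_distribute_for_static_loop_4u(void *Ident,
                                          void (*LoopBodyFn)(uint32_t, void *),
                                          void *LoopBodyArg, uint32_t TripCount,
                                          uint32_t NumThreads,
                                          uint32_t BlockChunk,
                                          uint32_t ThreadChunk);
} // extern "C"
```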

static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
CanonicalLoopInfo *CLI, Value *Ident,
Function &OutlinedFn, Type *ParallelTaskPtr,
const SmallVector<Instruction *, 4> &ToBeDeleted,
WorksharingLoopType LoopType) {
IRBuilder<> &Builder = OMPIRBuilder->Builder;
BasicBlock *Preheader = CLI->getPreheader();
Value *TripCount = CLI->getTripCount();

// After loop body outlining, the loop body contains only the setup of the
// loop body argument structure and the call to the outlined loop body
// function. First, we need to move the setup of the loop body arguments
// into the loop preheader.
Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

// The next step is to remove the whole loop. We do not need it anymore.
// That's why we make an unconditional branch from the loop preheader to the
// loop exit block.
Builder.restoreIP({Preheader, Preheader->end()});
Preheader->getTerminator()->eraseFromParent();
Builder.CreateBr(CLI->getExit());

// Delete dead loop blocks
OpenMPIRBuilder::OutlineInfo CleanUpInfo;
SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
CleanUpInfo.EntryBB = CLI->getHeader();
CleanUpInfo.ExitBB = CLI->getExit();
CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
DeleteDeadBlocks(BlocksToBeRemoved);

// Find the instruction which corresponds to the loop body argument structure
// and remove the call instruction to the loop body function.
Value *LoopBodyArg;
User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
assert(OutlinedFnUser &&
"Expected unique undroppable user of outlined function");
CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
assert(OutlinedFnCallInstruction && "Expected outlined function call");
assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
"Expected outlined function call to be located in loop preheader");
// Check in case no argument structure has been passed.
if (OutlinedFnCallInstruction->arg_size() > 1)
LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
else
LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
OutlinedFnCallInstruction->eraseFromParent();

createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
LoopBodyArg, ParallelTaskPtr, TripCount,
OutlinedFn);

for (auto &ToBeDeletedItem : ToBeDeleted)
ToBeDeletedItem->eraseFromParent();
CLI->invalidate();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
InsertPointTy AllocaIP,
WorksharingLoopType LoopType) {
uint32_t SrcLocStrSize;
Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

OutlineInfo OI;
OI.OuterAllocaBB = CLI->getPreheader();
Function *OuterFn = CLI->getPreheader()->getParent();

// Instructions which need to be deleted at the end of code generation
SmallVector<Instruction *, 4> ToBeDeleted;

OI.OuterAllocaBB = AllocaIP.getBlock();

// Mark the loop body as the region which needs to be extracted
OI.EntryBB = CLI->getBody();
OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
"omp.prelatch", true);

// Prepare loop body for extraction
Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

// Insert new loop counter variable which will be used only in loop
// body.
AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
Instruction *NewLoopCntLoad =
Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
// New loop counter instructions are redundant in the loop preheader when
// code generation for the workshare loop is finished. That's why we mark
// them as ready for deletion.
ToBeDeleted.push_back(NewLoopCntLoad);
ToBeDeleted.push_back(NewLoopCnt);

// Analyse loop body region. Find all input variables which are used inside
// loop body region.
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
SmallVector<BasicBlock *, 32> Blocks;
OI.collectBlocks(ParallelRegionBlockSet, Blocks);
SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
ParallelRegionBlockSet.end());

CodeExtractorAnalysisCache CEAC(*OuterFn);
CodeExtractor Extractor(Blocks,
/* DominatorTree */ nullptr,
/* AggregateArgs */ true,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocationBlock */ CLI->getPreheader(),
/* Suffix */ ".omp_wsloop",
/* AggrArgsIn0AddrSpace */ true);

BasicBlock *CommonExit = nullptr;
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;

// Find allocas outside the loop body region which are used inside loop
// body
Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

// We need to model the loop body region as the function f(cnt, loop_arg).
// That's why we replace the loop induction variable with the new counter
// which will become one of the loop body function arguments.
for (auto Use = CLI->getIndVar()->user_begin();
Use != CLI->getIndVar()->user_end(); ++Use) {
if (Instruction *Inst = dyn_cast<Instruction>(*Use)) {
if (ParallelRegionBlockSet.count(Inst->getParent())) {
Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
}
}
}
// Make sure that the loop counter variable is not merged into the loop body
// function argument structure and that it is passed as a separate variable.
OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);

// The PostOutline callback is invoked when the loop body function has been
// outlined and the loop body has been replaced by a call to the outlined
// function. We need to add a call to the OpenMP device RTL in the loop
// preheader. The OpenMP device RTL function will handle the loop control
// logic.
//
OI.PostOutlineCB = [=, ToBeDeletedVec =
std::move(ToBeDeleted)](Function &OutlinedFn) {
workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
ToBeDeletedVec, LoopType);
};
addOutlineInfo(std::move(OI));
return CLI->getAfterIP();
}
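
Taken together with `workshareLoopTargetCallback` above, the net effect is that the loop body is outlined into a function of the induction-variable value plus an aggregate of captured values (the `f(cnt, loop_arg)` model mentioned in the comments), and the rewritten preheader hands that function to the device RTL instead of branching into the loop. A hypothetical C++-level picture of the result; all names and the aggregate layout are illustrative only:

```cpp
#include <cstdint>

// Illustrative stand-in for the aggregate the CodeExtractor builds from the
// values captured by the loop body (AggregateArgs == true).
struct OmpWsloopArgs {
  int *A;
  int *B;
};

// Conceptual shape of the outlined loop body: one iteration, parameterized by
// the induction-variable value. The new loop counter (NewLoopCntLoad) is
// excluded from the aggregate and becomes this separate first parameter.
static void omp_wsloop_body(uint32_t IV, OmpWsloopArgs *Args) {
  Args->A[IV] = Args->B[IV];
}

// After the rewrite, the preheader conceptually sets up OmpWsloopArgs,
// queries omp_get_num_threads() when needed, calls the __kmpc_*_static_loop_*
// entry point selected by getKmpcForStaticLoopForType with &omp_wsloop_body,
// the aggregate, and the trip count, and then branches straight to the old
// loop exit; the original loop blocks are deleted.
```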

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
bool HasNonmonotonicModifier, bool HasOrderedClause) {
bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
bool HasSimdModifier, bool HasMonotonicModifier,
bool HasNonmonotonicModifier, bool HasOrderedClause,
WorksharingLoopType LoopType) {
if (Config.isTargetDevice())
return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
HasNonmonotonicModifier, HasOrderedClause);