Skip to content

[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode #116051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Basic/Cuda.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"

using namespace clang;
Expand Down Expand Up @@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
CodeGenFunction &CGF,
EntryFunctionState &EST, bool IsSPMD) {
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
Attrs.IsSPMD = IsSPMD;
Attrs.ExecFlags =
IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
: llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);

CGBuilderTy &Bld = CGF.Builder;
Expand Down
38 changes: 33 additions & 5 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,9 +1389,6 @@ class OpenMPIRBuilder {

/// Supporting functions for Reductions CodeGen.
private:
/// Emit the llvm.used metadata.
void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);

/// Get the id of the current thread on the GPU.
Value *getGPUThreadID();

Expand Down Expand Up @@ -2013,6 +2010,13 @@ class OpenMPIRBuilder {
/// Value.
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);

/// Emit the llvm.used metadata.
void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);

/// Emit the kernel execution mode.
GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
omp::OMPTgtExecModeFlags Mode);

/// Generate control flow and cleanup for cancellation.
///
/// \param CancelFlag Flag indicating if the cancellation is performed.
Expand Down Expand Up @@ -2233,13 +2237,34 @@ class OpenMPIRBuilder {
/// time. The number of max values will be 1 except for the case where
/// ompx_bare is set.
struct TargetKernelDefaultAttrs {
bool IsSPMD = false;
omp::OMPTgtExecModeFlags ExecFlags =
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
SmallVector<int32_t, 3> MaxTeams = {-1};
int32_t MinTeams = 1;
SmallVector<int32_t, 3> MaxThreads = {-1};
int32_t MinThreads = 1;
};

/// Container to pass LLVM IR runtime values or constants related to the
/// number of teams and threads with which the kernel must be launched, as
/// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
/// kernel. These must be defined in the host prior to the call to the kernel
/// launch OpenMP RTL function.
struct TargetKernelRuntimeAttrs {
SmallVector<Value *, 3> MaxTeams = {nullptr};
Value *MinTeams = nullptr;
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};

/// 'parallel' construct 'num_threads' clause value, if present and it is an
/// SPMD kernel.
Value *MaxThreads = nullptr;

/// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
/// it is a generic kernel.
Value *LoopTripCount = nullptr;
};

/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
Expand Down Expand Up @@ -2971,7 +2996,9 @@ class OpenMPIRBuilder {
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// \param DefaultAttrs Structure containing the default attributes, including
/// numbers of threads and teams to launch the kernel with.
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
/// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
Expand All @@ -2987,6 +3014,7 @@ class OpenMPIRBuilder {
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
Expand Down
129 changes: 92 additions & 37 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
return GV;
}

void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
if (List.empty())
return;

// Convert List to what ConstantArray needs.
SmallVector<Constant *, 8> UsedArray;
UsedArray.resize(List.size());
for (unsigned I = 0, E = List.size(); I != E; ++I)
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
cast<Constant>(&*List[I]), Builder.getPtrTy());

if (UsedArray.empty())
return;
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
ConstantArray::get(ATy, UsedArray), Name);

GV->setSection("llvm.metadata");
}

GlobalVariable *
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
OMPTgtExecModeFlags Mode) {
auto *Int8Ty = Builder.getInt8Ty();
auto *GVMode = new GlobalVariable(
M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
return GVMode;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
uint32_t SrcLocStrSize,
IdentFlag LocFlags,
Expand Down Expand Up @@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}

void OpenMPIRBuilder::emitUsed(StringRef Name,
std::vector<WeakTrackingVH> &List) {
if (List.empty())
return;

// Convert List to what ConstantArray needs.
SmallVector<Constant *, 8> UsedArray;
UsedArray.resize(List.size());
for (unsigned I = 0, E = List.size(); I != E; ++I)
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
cast<Constant>(&*List[I]), Builder.getPtrTy());

if (UsedArray.empty())
return;
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
ConstantArray::get(ATy, UsedArray), Name);

GV->setSection("llvm.metadata");
}

Value *OpenMPIRBuilder::getGPUThreadID() {
return Builder.CreateCall(
getOrCreateRuntimeFunction(M,
Expand Down Expand Up @@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
uint32_t SrcLocStrSize;
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
Constant *IsSPMDVal = ConstantInt::getSigned(
Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
Constant *UseGenericStateMachineVal =
ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);

Expand Down Expand Up @@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);

if (OMPBuilder.Config.isTargetDevice()) {
Value *ExecMode =
OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
}

// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
Expand Down Expand Up @@ -7312,6 +7327,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
Expand Down Expand Up @@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefaultVal, RuntimeVal] :
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));

// Calculate number of threads: 0 if no clauses specified, otherwise it is the
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result = Result
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
Result, Clause)
: Clause;
};

// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
for (auto V : DefaultAttrs.MaxTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
for (auto V : DefaultAttrs.MaxThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
: nullptr;

for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
RuntimeAttrs.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
Value *NumThreads = InitMaxThreadsClause(TargetVal);

CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
}

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Expand All @@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
// TODO: Use correct NumIterations
Value *NumIterations = Builder.getInt64(0);

Value *TripCount = RuntimeAttrs.LoopTripCount
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

KArgs = OpenMPIRBuilder::TargetKernelArgs(
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
Expand All @@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
Expand All @@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
HasNowait);
return Builder.saveIP();
}

Expand Down
Loading
Loading