Skip to content

Commit e3cdc93

Browse files
committed
[OMPIRBuilder] Introduce struct to hold default kernel teams/threads
This patch introduces the `OpenMPIRBuilder::TargetKernelDefaultAttrs` structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future. This is used to forward values passed to `createTarget` to `createTargetInit`, which previously used a default unrelated set of values.
1 parent 27ffa9f commit e3cdc93

File tree

8 files changed

+102
-81
lines changed

8 files changed

+102
-81
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5881,10 +5881,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
58815881

58825882
void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
58835883
const OMPExecutableDirective &D, CodeGenFunction &CGF,
5884-
int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
5885-
int32_t &MaxTeamsVal) {
5884+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
5885+
assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
5886+
"invalid default attrs structure");
5887+
int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
5888+
int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
58865889

5887-
getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
5890+
getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
58885891
getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
58895892
/*UpperBoundOnly=*/true);
58905893

@@ -5902,12 +5905,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
59025905
else
59035906
continue;
59045907

5905-
MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
5908+
Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
59065909
if (AttrMaxThreadsVal > 0)
59075910
MaxThreadsVal = MaxThreadsVal > 0
59085911
? std::min(MaxThreadsVal, AttrMaxThreadsVal)
59095912
: AttrMaxThreadsVal;
5910-
MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
5913+
Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
59115914
if (AttrMaxBlocksVal > 0)
59125915
MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
59135916
: AttrMaxBlocksVal;

clang/lib/CodeGen/CGOpenMPRuntime.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,9 @@ class CGOpenMPRuntime {
313313
llvm::OpenMPIRBuilder OMPBuilder;
314314

315315
/// Helper to determine the min/max number of threads/teams for \p D.
316-
void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
317-
CodeGenFunction &CGF,
318-
int32_t &MinThreadsVal,
319-
int32_t &MaxThreadsVal,
320-
int32_t &MinTeamsVal,
321-
int32_t &MaxTeamsVal);
316+
void computeMinAndMaxThreadsAndTeams(
317+
const OMPExecutableDirective &D, CodeGenFunction &CGF,
318+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
322319

323320
/// Helper to emit outlined function for 'target' directive.
324321
/// \param D Directive to emit.

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
744744
void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
745745
CodeGenFunction &CGF,
746746
EntryFunctionState &EST, bool IsSPMD) {
747-
int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
748-
MaxTeamsVal = -1;
749-
computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
750-
MinTeamsVal, MaxTeamsVal);
747+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
748+
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
751749

752750
CGBuilderTy &Bld = CGF.Builder;
753-
Bld.restoreIP(OMPBuilder.createTargetInit(
754-
Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
751+
Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
755752
if (!IsSPMD)
756753
emitGenericVarsProlog(CGF, EST.Loc);
757754
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
22232223
MapNamesArray(MapNamesArray) {}
22242224
};
22252225

2226+
/// Container to pass the default attributes with which a kernel must be
2227+
/// launched, used to set kernel attributes and populate associated static
2228+
/// structures.
2229+
///
2230+
/// For max values, < 0 means unset, == 0 means set but unknown at compile
2231+
/// time. The number of max values will be 1 except for the case where
2232+
/// ompx_bare is set.
2233+
struct TargetKernelDefaultAttrs {
2234+
SmallVector<int32_t, 3> MaxTeams = {-1};
2235+
int32_t MinTeams = 1;
2236+
SmallVector<int32_t, 3> MaxThreads = {-1};
2237+
int32_t MinThreads = 1;
2238+
};
2239+
22262240
/// Data structure that contains the needed information to construct the
22272241
/// kernel args vector.
22282242
struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
27262740
///
27272741
/// \param Loc The insert and source location description.
27282742
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
2729-
/// \param MinThreads Minimal number of threads, or 0.
2730-
/// \param MaxThreads Maximal number of threads, or 0.
2731-
/// \param MinTeams Minimal number of teams, or 0.
2732-
/// \param MaxTeams Maximal number of teams, or 0.
2733-
InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
2734-
int32_t MinThreadsVal = 0,
2735-
int32_t MaxThreadsVal = 0,
2736-
int32_t MinTeamsVal = 0,
2737-
int32_t MaxTeamsVal = 0);
2743+
/// \param Attrs Structure containing the default numbers of threads and teams
2744+
/// to launch the kernel with.
2745+
InsertPointTy createTargetInit(
2746+
const LocationDescription &Loc, bool IsSPMD,
2747+
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
27382748

27392749
/// Create a runtime call for kmpc_target_deinit
27402750
///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
28982908
/// \param CodeGenIP The insertion point where the call to the outlined
28992909
/// function should be emitted.
29002910
/// \param EntryInfo The entry information about the function.
2901-
/// \param NumTeams Number of teams specified in the num_teams clause.
2902-
/// \param NumThreads Number of teams specified in the thread_limit clause.
2911+
/// \param DefaultAttrs Structure containing the default numbers of threads
2912+
/// and teams to launch the kernel with.
29032913
/// \param Inputs The input values to the region that will be passed.
29042914
/// as arguments to the outlined function.
29052915
/// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
29122922
const LocationDescription &Loc, bool IsOffloadEntry,
29132923
OpenMPIRBuilder::InsertPointTy AllocaIP,
29142924
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2915-
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
2916-
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
2917-
GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
2925+
TargetRegionEntryInfo &EntryInfo,
2926+
const TargetKernelDefaultAttrs &DefaultAttrs,
2927+
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
2928+
TargetBodyGenCallbackTy BodyGenCB,
29182929
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
29192930
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
29202931

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6113,10 +6113,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
61136113
return Builder.CreateCall(Fn, Args);
61146114
}
61156115

6116-
OpenMPIRBuilder::InsertPointTy
6117-
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6118-
int32_t MinThreadsVal, int32_t MaxThreadsVal,
6119-
int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6116+
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6117+
const LocationDescription &Loc, bool IsSPMD,
6118+
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6119+
assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
6120+
"expected num_threads and num_teams to be specified");
6121+
61206122
if (!updateToLocation(Loc))
61216123
return Loc.IP;
61226124

@@ -6143,21 +6145,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
61436145

61446146
// Manifest the launch configuration in the metadata matching the kernel
61456147
// environment.
6146-
if (MinTeamsVal > 1 || MaxTeamsVal > 0)
6147-
writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
6148+
if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
6149+
writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
61486150

6149-
// For max values, < 0 means unset, == 0 means set but unknown.
6151+
// If MaxThreads not set, select the maximum between the default workgroup
6152+
// size and the MinThreads value.
6153+
int32_t MaxThreadsVal = Attrs.MaxThreads.front();
61506154
if (MaxThreadsVal < 0)
61516155
MaxThreadsVal = std::max(
6152-
int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
6156+
int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
61536157

61546158
if (MaxThreadsVal > 0)
6155-
writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
6159+
writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
61566160

6157-
Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
6161+
Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
61586162
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6159-
Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
6160-
Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
6163+
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
6164+
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
61616165
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
61626166
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
61636167

@@ -6728,8 +6732,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67286732
}
67296733

67306734
static Expected<Function *> createOutlinedFunction(
6731-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6732-
SmallVectorImpl<Value *> &Inputs,
6735+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6736+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6737+
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67336738
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
67346739
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
67356740
SmallVector<Type *> ParameterTypes;
@@ -6796,7 +6801,8 @@ static Expected<Function *> createOutlinedFunction(
67966801

67976802
// Insert target init call in the device compilation pass.
67986803
if (OMPBuilder.Config.isTargetDevice())
6799-
Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
6804+
Builder.restoreIP(
6805+
OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
68006806

68016807
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68026808

@@ -6992,16 +6998,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69926998

69936999
static Error emitTargetOutlinedFunction(
69947000
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6995-
TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6996-
Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
7001+
TargetRegionEntryInfo &EntryInfo,
7002+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7003+
Function *&OutlinedFn, Constant *&OutlinedFnID,
7004+
SmallVectorImpl<Value *> &Inputs,
69977005
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
69987006
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
69997007

70007008
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7001-
[&OMPBuilder, &Builder, &Inputs, &CBFunc,
7002-
&ArgAccessorFuncCB](StringRef EntryFnName) {
7003-
return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
7004-
CBFunc, ArgAccessorFuncCB);
7009+
[&](StringRef EntryFnName) {
7010+
return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7011+
EntryFnName, Inputs, CBFunc,
7012+
ArgAccessorFuncCB);
70057013
};
70067014

70077015
return OMPBuilder.emitTargetRegionFunction(
@@ -7297,9 +7305,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
72977305

72987306
static void
72997307
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7300-
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7301-
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
7302-
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
7308+
OpenMPIRBuilder::InsertPointTy AllocaIP,
7309+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7310+
Function *OutlinedFn, Constant *OutlinedFnID,
7311+
SmallVectorImpl<Value *> &Args,
73037312
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73047313
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
73057314
bool HasNoWait = false) {
@@ -7380,9 +7389,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73807389

73817390
SmallVector<Value *, 3> NumTeamsC;
73827391
SmallVector<Value *, 3> NumThreadsC;
7383-
for (auto V : NumTeams)
7392+
for (auto V : DefaultAttrs.MaxTeams)
73847393
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7385-
for (auto V : NumThreads)
7394+
for (auto V : DefaultAttrs.MaxThreads)
73867395
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
73877396

73887397
unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7423,7 +7432,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74237432
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74247433
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74257434
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7426-
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
7435+
const TargetKernelDefaultAttrs &DefaultAttrs,
74277436
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74287437
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74297438
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7440,16 +7449,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74407449
// the target region itself is generated using the callbacks CBFunc
74417450
// and ArgAccessorFuncCB
74427451
if (Error Err = emitTargetOutlinedFunction(
7443-
*this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
7444-
Args, CBFunc, ArgAccessorFuncCB))
7452+
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7453+
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74457454
return Err;
74467455

74477456
// If we are not on the target device, then we need to generate code
74487457
// to make a remote call (offload) to the previously outlined function
74497458
// that represents the target region. Do that now.
74507459
if (!Config.isTargetDevice())
7451-
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7452-
NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
7460+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7461+
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
74537462
return Builder.saveIP();
74547463
}
74557464

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
61826182

61836183
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
61846184
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
6185-
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6186-
OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
6187-
EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
6185+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6186+
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
6187+
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6188+
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
6189+
Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
6190+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
61886191
assert(AfterIP && "unexpected error");
61896192
Builder.restoreIP(*AfterIP);
61906193
OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
62926295
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
62936296
/*Line=*/3, /*Count=*/0);
62946297

6295-
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6296-
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6297-
EntryInfo, /*NumTeams=*/-1,
6298-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6299-
BodyGenCB, SimpleArgAccessorCB);
6298+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6299+
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
6300+
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6301+
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
6302+
CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
63006303
assert(AfterIP && "unexpected error");
63016304
Builder.restoreIP(*AfterIP);
63026305

@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
64436446
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
64446447
/*Line=*/3, /*Count=*/0);
64456448

6446-
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6447-
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6448-
EntryInfo, /*NumTeams=*/-1,
6449-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6450-
BodyGenCB, SimpleArgAccessorCB);
6449+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6450+
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
6451+
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6452+
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
6453+
CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
64516454
assert(AfterIP && "unexpected error");
64526455
Builder.restoreIP(*AfterIP);
64536456

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3917,9 +3917,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39173917
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
39183918
return failure();
39193919

3920-
int32_t defaultValTeams = -1;
3921-
int32_t defaultValThreads = 0;
3922-
39233920
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
39243921
findAllocaInsertPoint(builder, moduleTranslation);
39253922

@@ -3954,6 +3951,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39543951
allocaIP, codeGenIP);
39553952
};
39563953

3954+
// TODO: Populate default attributes based on the construct and clauses.
3955+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
3956+
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
3957+
39573958
llvm::SmallVector<llvm::Value *, 4> kernelInput;
39583959
for (size_t i = 0; i < mapVars.size(); ++i) {
39593960
// declare target arguments are not passed to kernels as arguments
@@ -3973,8 +3974,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39733974
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
39743975
moduleTranslation.getOpenMPBuilder()->createTarget(
39753976
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
3976-
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
3977-
argAccessorCB, dds, targetOp.getNowait());
3977+
defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
3978+
targetOp.getNowait());
39783979

39793980
if (failed(handleError(afterIP, opInst)))
39803981
return failure();

mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ module attributes {omp.is_target_device = true} {
2929
// CHECK: @[[SRC_LOC:.*]] = private unnamed_addr constant [23 x i8] c"{{[^"]*}}", align 1
3030
// CHECK: @[[IDENT:.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @[[SRC_LOC]] }, align 8
3131
// CHECK: @[[DYNA_ENV:.*]] = weak_odr protected global %struct.DynamicEnvironmentTy zeroinitializer
32-
// CHECK: @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] }
32+
// CHECK: @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 0, i32 0, i32 0, i32 -1, i32 0, i32 0 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] }
3333
// CHECK: define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[DYN_PTR:.*]], ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
3434
// CHECK: %[[TMP_A:.*]] = alloca ptr, align 8
3535
// CHECK: store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8

0 commit comments

Comments
 (0)