Skip to content

Commit 66fe4f9

Browse files
committed
[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers
This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive. Since both MLIR and Clang now support custom mappers, I've made the relative params required instead of optional as well. Depends on #121005
1 parent af3dcbf commit 66fe4f9

File tree

7 files changed

+437
-99
lines changed

7 files changed

+437
-99
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8889,8 +8889,8 @@ static void emitOffloadingArraysAndArgs(
88898889
return MFunc;
88908890
};
88918891
OMPBuilder.emitOffloadingArraysAndArgs(
8892-
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
8893-
ForEndCall, DeviceAddrCB, CustomMapperCB);
8892+
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8893+
IsNonContiguous, ForEndCall, DeviceAddrCB);
88948894
}
88958895

88968896
/// Check for inner distribute directive.
@@ -9099,9 +9099,10 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
90999099
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
91009100
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91019101

9102-
auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
9103-
ElemTy, Name, CustomMapperCB);
9104-
UDMMap.try_emplace(D, NewFn);
9102+
llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
9103+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
9104+
assert(NewFn && "Unexpected error in emitUserDefinedMapper");
9105+
UDMMap.try_emplace(D, *NewFn);
91059106
if (CGF)
91069107
FunctionUDMMap[CGF->CurFn].push_back(D);
91079108
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,7 @@ class OpenMPIRBuilder {
23992399
CurInfo.NonContigInfo.Strides.end());
24002400
}
24012401
};
2402+
using MapInfosOrErrorTy = Expected<MapInfosTy &>;
24022403

24032404
/// Callback function type for functions emitting the host fallback code that
24042405
/// is executed when the kernel launch fails. It takes an insertion point as
@@ -2475,9 +2476,9 @@ class OpenMPIRBuilder {
24752476
/// including base pointers, pointers, sizes, map types, user-defined mappers.
24762477
void emitOffloadingArrays(
24772478
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2478-
TargetDataInfo &Info, bool IsNonContiguous = false,
2479-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2480-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2479+
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
2480+
bool IsNonContiguous = false,
2481+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24812482

24822483
/// Allocates memory for and populates the arrays required for offloading
24832484
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
@@ -2488,9 +2489,9 @@ class OpenMPIRBuilder {
24882489
void emitOffloadingArraysAndArgs(
24892490
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
24902491
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2492+
function_ref<Value *(unsigned int)> CustomMapperCB,
24912493
bool IsNonContiguous = false, bool ForEndCall = false,
2492-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2493-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2494+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24942495

24952496
/// Creates offloading entry for the provided entry ID \a ID, address \a
24962497
/// Addr, size \a Size, and flags \a Flags.
@@ -2950,12 +2951,12 @@ class OpenMPIRBuilder {
29502951
/// \param FuncName Optional param to specify mapper function name.
29512952
/// \param CustomMapperCB Optional callback to generate code related to
29522953
/// custom mappers.
2953-
Function *emitUserDefinedMapper(
2954-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
2955-
llvm::Value *BeginArg)>
2954+
Expected<Function *> emitUserDefinedMapper(
2955+
function_ref<MapInfosOrErrorTy(
2956+
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29562957
PrivAndGenMapInfoCB,
29572958
llvm::Type *ElemTy, StringRef FuncName,
2958-
function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
2959+
function_ref<bool(unsigned int, Function **)> CustomMapperCB);
29592960

29602961
/// Generator for '#omp target data'
29612962
///
@@ -2969,21 +2970,21 @@ class OpenMPIRBuilder {
29692970
/// \param IfCond Value which corresponds to the if clause condition.
29702971
/// \param Info Stores all information realted to the Target Data directive.
29712972
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
2973+
/// \param CustomMapperCB Callback to generate code related to
2974+
/// custom mappers.
29722975
/// \param BodyGenCB Optional Callback to generate the region code.
29732976
/// \param DeviceAddrCB Optional callback to generate code related to
29742977
/// use_device_ptr and use_device_addr.
2975-
/// \param CustomMapperCB Optional callback to generate code related to
2976-
/// custom mappers.
29772978
InsertPointOrErrorTy createTargetData(
29782979
const LocationDescription &Loc, InsertPointTy AllocaIP,
29792980
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
29802981
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
2982+
function_ref<Value *(unsigned int)> CustomMapperCB,
29812983
omp::RuntimeFunction *MapperFunc = nullptr,
29822984
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
29832985
BodyGenTy BodyGenType)>
29842986
BodyGenCB = nullptr,
29852987
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2986-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
29872988
Value *SrcLocInfo = nullptr);
29882989

29892990
using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -2999,6 +3000,7 @@ class OpenMPIRBuilder {
29993000
/// \param IsOffloadEntry whether it is an offload entry.
30003001
/// \param CodeGenIP The insertion point where the call to the outlined
30013002
/// function should be emitted.
3003+
/// \param Info Stores all information realted to the Target directive.
30023004
/// \param EntryInfo The entry information about the function.
30033005
/// \param DefaultAttrs Structure containing the default attributes, including
30043006
/// numbers of threads and teams to launch the kernel with.
@@ -3010,20 +3012,23 @@ class OpenMPIRBuilder {
30103012
/// \param BodyGenCB Callback that will generate the region code.
30113013
/// \param ArgAccessorFuncCB Callback that will generate accessors
30123014
/// instructions for passed in target arguments where neccessary
3015+
/// \param CustomMapperCB Callback to generate code related to
3016+
/// custom mappers.
30133017
/// \param Dependencies A vector of DependData objects that carry
30143018
/// dependency information as passed in the depend clause
30153019
/// \param HasNowait Whether the target construct has a `nowait` clause or
30163020
/// not.
30173021
InsertPointOrErrorTy createTarget(
30183022
const LocationDescription &Loc, bool IsOffloadEntry,
30193023
OpenMPIRBuilder::InsertPointTy AllocaIP,
3020-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
3024+
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
30213025
TargetRegionEntryInfo &EntryInfo,
30223026
const TargetKernelDefaultAttrs &DefaultAttrs,
30233027
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
30243028
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30253029
TargetBodyGenCallbackTy BodyGenCB,
30263030
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3031+
function_ref<Value *(unsigned int)> CustomMapperCB,
30273032
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30283033

30293034
/// Returns __kmpc_for_static_init_* runtime function for the specified

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6549,12 +6549,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65496549
const LocationDescription &Loc, InsertPointTy AllocaIP,
65506550
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
65516551
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6552+
function_ref<Value *(unsigned int)> CustomMapperCB,
65526553
omp::RuntimeFunction *MapperFunc,
65536554
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
65546555
BodyGenTy BodyGenType)>
65556556
BodyGenCB,
6556-
function_ref<void(unsigned int, Value *)> DeviceAddrCB,
6557-
function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
6557+
function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
65586558
if (!updateToLocation(Loc))
65596559
return InsertPointTy();
65606560

@@ -6580,8 +6580,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65806580
InsertPointTy CodeGenIP) -> Error {
65816581
MapInfo = &GenMapInfoCB(Builder.saveIP());
65826582
emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6583-
/*IsNonContiguous=*/true, DeviceAddrCB,
6584-
CustomMapperCB);
6583+
CustomMapperCB,
6584+
/*IsNonContiguous=*/true, DeviceAddrCB);
65856585

65866586
TargetDataRTArgs RTArgs;
65876587
emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7394,24 +7394,26 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
73947394

73957395
void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
73967396
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
7397-
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
7398-
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB,
7399-
function_ref<Value *(unsigned int)> CustomMapperCB) {
7400-
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous,
7401-
DeviceAddrCB, CustomMapperCB);
7397+
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7398+
function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
7399+
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7400+
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
7401+
IsNonContiguous, DeviceAddrCB);
74027402
emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
74037403
}
74047404

74057405
static void
74067406
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74077407
OpenMPIRBuilder::InsertPointTy AllocaIP,
7408+
OpenMPIRBuilder::TargetDataInfo &Info,
74087409
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
74097410
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
74107411
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
74117412
SmallVectorImpl<Value *> &Args,
74127413
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7413-
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7414-
bool HasNoWait = false) {
7414+
function_ref<Value *(unsigned int)> CustomMapperCB,
7415+
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
7416+
bool HasNoWait) {
74157417
// Generate a function call to the host fallback implementation of the target
74167418
// region. This is called by the host when no offload entry was generated for
74177419
// the target region and when the offloading call fails at runtime.
@@ -7489,7 +7491,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74897491
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
74907492
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
74917493
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7492-
RTArgs, MapInfo,
7494+
RTArgs, MapInfo, CustomMapperCB,
74937495
/*IsNonContiguous=*/true,
74947496
/*ForEndCall=*/false);
74957497

@@ -7593,12 +7595,14 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
75937595

75947596
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
75957597
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7596-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7598+
InsertPointTy CodeGenIP, TargetDataInfo &Info,
7599+
TargetRegionEntryInfo &EntryInfo,
75977600
const TargetKernelDefaultAttrs &DefaultAttrs,
75987601
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7599-
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7602+
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
76007603
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
76017604
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7605+
function_ref<Value *(unsigned int)> CustomMapperCB,
76027606
SmallVector<DependData> Dependencies, bool HasNowait) {
76037607

76047608
if (!updateToLocation(Loc))
@@ -7613,16 +7617,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
76137617
// and ArgAccessorFuncCB
76147618
if (Error Err = emitTargetOutlinedFunction(
76157619
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7616-
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7620+
OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
76177621
return Err;
76187622

76197623
// If we are not on the target device, then we need to generate code
76207624
// to make a remote call (offload) to the previously outlined function
76217625
// that represents the target region. Do that now.
76227626
if (!Config.isTargetDevice())
7623-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
7624-
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7625-
HasNowait);
7627+
emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
7628+
IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
7629+
CustomMapperCB, Dependencies, HasNowait);
76267630
return Builder.saveIP();
76277631
}
76287632

@@ -7947,9 +7951,9 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
79477951
OffloadingArgs);
79487952
}
79497953

7950-
Function *OpenMPIRBuilder::emitUserDefinedMapper(
7951-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
7952-
llvm::Value *BeginArg)>
7954+
Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
7955+
function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
7956+
llvm::Value *BeginArg)>
79537957
GenMapInfoCB,
79547958
Type *ElemTy, StringRef FuncName,
79557959
function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
@@ -8023,7 +8027,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
80238027
PtrPHI->addIncoming(PtrBegin, HeadBB);
80248028

80258029
// Get map clause information. Fill up the arrays with all mapped variables.
8026-
MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8030+
MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8031+
if (!Info)
8032+
return Info.takeError();
80278033

80288034
// Call the runtime API __tgt_mapper_num_components to get the number of
80298035
// pre-existing components.
@@ -8035,20 +8041,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
80358041
Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
80368042

80378043
// Fill up the runtime mapper handle for all components.
8038-
for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
8044+
for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
80398045
Value *CurBaseArg =
8040-
Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy());
8046+
Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy());
80418047
Value *CurBeginArg =
8042-
Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy());
8043-
Value *CurSizeArg = Info.Sizes[I];
8044-
Value *CurNameArg = Info.Names.size()
8045-
? Info.Names[I]
8048+
Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy());
8049+
Value *CurSizeArg = Info->Sizes[I];
8050+
Value *CurNameArg = Info->Names.size()
8051+
? Info->Names[I]
80468052
: Constant::getNullValue(Builder.getPtrTy());
80478053

80488054
// Extract the MEMBER_OF field from the map type.
80498055
Value *OriMapType = Builder.getInt64(
80508056
static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8051-
Info.Types[I]));
8057+
Info->Types[I]));
80528058
Value *MemberMapType =
80538059
Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
80548060

@@ -8169,9 +8175,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
81698175

81708176
void OpenMPIRBuilder::emitOffloadingArrays(
81718177
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8172-
TargetDataInfo &Info, bool IsNonContiguous,
8173-
function_ref<void(unsigned int, Value *)> DeviceAddrCB,
8174-
function_ref<Value *(unsigned int)> CustomMapperCB) {
8178+
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
8179+
bool IsNonContiguous,
8180+
function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
81758181

81768182
// Reset the array information.
81778183
Info.clearArrayInfo();

0 commit comments

Comments
 (0)