Skip to content

Commit 945337b

Browse files
committed
[OpenMP]Update use_device_clause lowering
This patch updates the use_device_ptr and use_device_addr clauses to use the mapInfoOps for lowering. This allows all the types that are handle by the map clauses such as derived types to also be supported by the use_device_clauses. This is patch 2/2 in a series of patches.
1 parent d996277 commit 945337b

File tree

4 files changed

+194
-135
lines changed

4 files changed

+194
-135
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6351,7 +6351,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
63516351
// Disable TargetData CodeGen on Device pass.
63526352
if (Config.IsTargetDevice.value_or(false)) {
63536353
if (BodyGenCB)
6354-
Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6354+
Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
63556355
return Builder.saveIP();
63566356
}
63576357

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 155 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,6 +2110,8 @@ getRefPtrIfDeclareTarget(mlir::Value value,
21102110
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
21112111
llvm::SmallVector<bool, 4> IsDeclareTarget;
21122112
llvm::SmallVector<bool, 4> IsAMember;
2113+
// Identify if mapping was added by mapClause or use_device clauses.
2114+
llvm::SmallVector<bool, 4> IsAMapping;
21132115
llvm::SmallVector<mlir::Operation *, 4> MapClause;
21142116
llvm::SmallVector<llvm::Value *, 4> OriginalValue;
21152117
// Stripped off array/pointer to get the underlying
@@ -2193,62 +2195,125 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
21932195
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
21942196
}
21952197

2196-
void collectMapDataFromMapVars(MapInfoData &mapData,
2197-
llvm::SmallVectorImpl<Value> &mapVars,
2198-
LLVM::ModuleTranslation &moduleTranslation,
2199-
DataLayout &dl, llvm::IRBuilderBase &builder) {
2198+
void collectMapDataFromMapOperands(
2199+
MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
2200+
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
2201+
llvm::IRBuilderBase &builder,
2202+
const llvm::ArrayRef<Value> &useDevPtrOperands = {},
2203+
const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
2204+
// Process MapOperands
22002205
for (mlir::Value mapValue : mapVars) {
2201-
if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2202-
mapValue.getDefiningOp())) {
2203-
mlir::Value offloadPtr =
2204-
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2205-
mapData.OriginalValue.push_back(
2206-
moduleTranslation.lookupValue(offloadPtr));
2207-
mapData.Pointers.push_back(mapData.OriginalValue.back());
2208-
2209-
if (llvm::Value *refPtr =
2210-
getRefPtrIfDeclareTarget(offloadPtr,
2211-
moduleTranslation)) { // declare target
2212-
mapData.IsDeclareTarget.push_back(true);
2213-
mapData.BasePointers.push_back(refPtr);
2214-
} else { // regular mapped variable
2215-
mapData.IsDeclareTarget.push_back(false);
2216-
mapData.BasePointers.push_back(mapData.OriginalValue.back());
2217-
}
2206+
auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
2207+
mlir::Value offloadPtr =
2208+
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2209+
mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
2210+
mapData.Pointers.push_back(mapData.OriginalValue.back());
2211+
2212+
if (llvm::Value *refPtr =
2213+
getRefPtrIfDeclareTarget(offloadPtr,
2214+
moduleTranslation)) { // declare target
2215+
mapData.IsDeclareTarget.push_back(true);
2216+
mapData.BasePointers.push_back(refPtr);
2217+
} else { // regular mapped variable
2218+
mapData.IsDeclareTarget.push_back(false);
2219+
mapData.BasePointers.push_back(mapData.OriginalValue.back());
2220+
}
22182221

2219-
mapData.BaseType.push_back(
2220-
moduleTranslation.convertType(mapOp.getVarType()));
2221-
mapData.Sizes.push_back(
2222-
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2223-
mapData.BaseType.back(), builder, moduleTranslation));
2224-
mapData.MapClause.push_back(mapOp.getOperation());
2225-
mapData.Types.push_back(
2226-
llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
2227-
mapData.Names.push_back(LLVM::createMappingInformation(
2228-
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
2229-
mapData.DevicePointers.push_back(
2230-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2231-
2232-
// Check if this is a member mapping and correctly assign that it is, if
2233-
// it is a member of a larger object.
2234-
// TODO: Need better handling of members, and distinguishing of members
2235-
// that are implicitly allocated on device vs explicitly passed in as
2236-
// arguments.
2237-
// TODO: May require some further additions to support nested record
2238-
// types, i.e. member maps that can have member maps.
2239-
mapData.IsAMember.push_back(false);
2240-
for (mlir::Value mapValue : mapVars) {
2241-
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2242-
mapValue.getDefiningOp())) {
2243-
for (auto member : map.getMembers()) {
2244-
if (member == mapOp) {
2245-
mapData.IsAMember.back() = true;
2246-
}
2222+
mapData.BaseType.push_back(
2223+
moduleTranslation.convertType(mapOp.getVarType()));
2224+
mapData.Sizes.push_back(
2225+
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2226+
mapData.BaseType.back(), builder, moduleTranslation));
2227+
mapData.MapClause.push_back(mapOp.getOperation());
2228+
mapData.Types.push_back(
2229+
llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
2230+
mapData.Names.push_back(LLVM::createMappingInformation(
2231+
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
2232+
mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2233+
mapData.IsAMapping.push_back(true);
2234+
2235+
// Check if this is a member mapping and correctly assign that it is, if
2236+
// it is a member of a larger object.
2237+
// TODO: Need better handling of members, and distinguishing of members
2238+
// that are implicitly allocated on device vs explicitly passed in as
2239+
// arguments.
2240+
// TODO: May require some further additions to support nested record
2241+
// types, i.e. member maps that can have member maps.
2242+
mapData.IsAMember.push_back(false);
2243+
for (mlir::Value mapValue : mapVars) {
2244+
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2245+
mapValue.getDefiningOp())) {
2246+
for (auto member : map.getMembers()) {
2247+
if (member == mapOp) {
2248+
mapData.IsAMember.back() = true;
22472249
}
22482250
}
22492251
}
22502252
}
22512253
}
2254+
2255+
auto findMapInfo = [&mapData](llvm::Value *val,
2256+
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2257+
unsigned index = 0;
2258+
bool found = false;
2259+
for (llvm::Value *basePtr : mapData.OriginalValue) {
2260+
if (basePtr == val && mapData.IsAMapping[index]) {
2261+
found = true;
2262+
mapData.Types[index] |=
2263+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2264+
mapData.DevicePointers[index] = devInfoTy;
2265+
}
2266+
index++;
2267+
}
2268+
return found;
2269+
};
2270+
2271+
// Process useDevPtr(Addr)Operands
2272+
auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
2273+
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2274+
for (mlir::Value mapValue : useDevOperands) {
2275+
auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
2276+
mlir::Value offloadPtr =
2277+
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2278+
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
2279+
2280+
// Check if map info is already present for this entry.
2281+
if (!findMapInfo(origValue, devInfoTy)) {
2282+
mapData.OriginalValue.push_back(origValue);
2283+
mapData.Pointers.push_back(mapData.OriginalValue.back());
2284+
mapData.IsDeclareTarget.push_back(false);
2285+
mapData.BasePointers.push_back(mapData.OriginalValue.back());
2286+
mapData.BaseType.push_back(
2287+
moduleTranslation.convertType(mapOp.getVarType()));
2288+
mapData.Sizes.push_back(builder.getInt64(0));
2289+
mapData.MapClause.push_back(mapOp.getOperation());
2290+
mapData.Types.push_back(
2291+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2292+
mapData.Names.push_back(LLVM::createMappingInformation(
2293+
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
2294+
mapData.DevicePointers.push_back(devInfoTy);
2295+
mapData.IsAMapping.push_back(true);
2296+
2297+
// Check if this is a member mapping and correctly assign that it is,
2298+
// if it is a member of a larger object.
2299+
// TODO: Need better handling of members, and distinguishing of
2300+
// members that are implicitly allocated on device vs explicitly
2301+
// passed in as arguments.
2302+
// TODO: May require some further additions to support nested record
2303+
// types, i.e. member maps that can have member maps.
2304+
mapData.IsAMember.push_back(false);
2305+
for (mlir::Value mapValue : useDevOperands)
2306+
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2307+
mapValue.getDefiningOp()))
2308+
for (auto member : map.getMembers())
2309+
if (member == mapOp)
2310+
mapData.IsAMember.back() = true;
2311+
}
2312+
}
2313+
};
2314+
2315+
addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2316+
addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
22522317
}
22532318

22542319
static int getMapDataMemberIdx(MapInfoData &mapData,
@@ -2426,7 +2491,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
24262491
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
24272492
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
24282493
combinedInfo.DevicePointers.emplace_back(
2429-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2494+
mapData.DevicePointers[mapDataIndex]);
24302495
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
24312496
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
24322497
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2553,7 +2618,7 @@ static void processMapMembersWithParent(
25532618

25542619
combinedInfo.Types.emplace_back(mapFlag);
25552620
combinedInfo.DevicePointers.emplace_back(
2556-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2621+
mapData.DevicePointers[memberDataIdx]);
25572622
combinedInfo.Names.emplace_back(
25582623
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
25592624
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2714,10 +2779,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
27142779
LLVM::ModuleTranslation &moduleTranslation,
27152780
DataLayout &dl,
27162781
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2717-
MapInfoData &mapData,
2718-
const SmallVector<Value> &useDevicePtrVars = {},
2719-
const SmallVector<Value> &useDeviceAddrVars = {},
2720-
bool isTargetParams = false) {
2782+
MapInfoData &mapData, bool isTargetParams = false) {
27212783
// We wish to modify some of the methods in which arguments are
27222784
// passed based on their capture type by the target region, this can
27232785
// involve generating new loads and stores, which changes the
@@ -2734,15 +2796,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
27342796

27352797
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
27362798

2737-
auto fail = [&combinedInfo]() -> void {
2738-
combinedInfo.BasePointers.clear();
2739-
combinedInfo.Pointers.clear();
2740-
combinedInfo.DevicePointers.clear();
2741-
combinedInfo.Sizes.clear();
2742-
combinedInfo.Types.clear();
2743-
combinedInfo.Names.clear();
2744-
};
2745-
27462799
// We operate under the assumption that all vectors that are
27472800
// required in MapInfoData are of equal lengths (either filled with
27482801
// default constructed data or appropiate information) so we can
@@ -2763,46 +2816,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
27632816

27642817
processIndividualMap(mapData, i, combinedInfo, isTargetParams);
27652818
}
2766-
2767-
auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
2768-
index = 0;
2769-
for (llvm::Value *basePtr : combinedInfo.BasePointers) {
2770-
if (basePtr == val)
2771-
return true;
2772-
index++;
2773-
}
2774-
return false;
2775-
};
2776-
2777-
auto addDevInfos = [&, fail](auto useDeviceVars, auto devOpType) -> void {
2778-
for (const auto &useDeviceVar : useDeviceVars) {
2779-
// TODO: Only LLVMPointerTypes are handled.
2780-
if (!isa<LLVM::LLVMPointerType>(useDeviceVar.getType()))
2781-
return fail();
2782-
2783-
llvm::Value *mapOpValue = moduleTranslation.lookupValue(useDeviceVar);
2784-
2785-
// Check if map info is already present for this entry.
2786-
unsigned infoIndex;
2787-
if (findMapInfo(mapOpValue, infoIndex)) {
2788-
combinedInfo.Types[infoIndex] |=
2789-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2790-
combinedInfo.DevicePointers[infoIndex] = devOpType;
2791-
} else {
2792-
combinedInfo.BasePointers.emplace_back(mapOpValue);
2793-
combinedInfo.Pointers.emplace_back(mapOpValue);
2794-
combinedInfo.DevicePointers.emplace_back(devOpType);
2795-
combinedInfo.Names.emplace_back(
2796-
LLVM::createMappingInformation(useDeviceVar.getLoc(), *ompBuilder));
2797-
combinedInfo.Types.emplace_back(
2798-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2799-
combinedInfo.Sizes.emplace_back(builder.getInt64(0));
2800-
}
2801-
}
2802-
};
2803-
2804-
addDevInfos(useDevicePtrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2805-
addDevInfos(useDeviceAddrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
28062819
}
28072820

28082821
static LogicalResult
@@ -2899,19 +2912,15 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
28992912
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
29002913

29012914
MapInfoData mapData;
2902-
collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, DL, builder);
2915+
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
2916+
builder, useDevicePtrVars, useDeviceAddrVars);
29032917

29042918
// Fill up the arrays with all the mapped variables.
29052919
llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
29062920
auto genMapInfoCB =
29072921
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
29082922
builder.restoreIP(codeGenIP);
2909-
if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
2910-
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
2911-
useDevicePtrVars, useDeviceAddrVars);
2912-
} else {
2913-
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
2914-
}
2923+
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
29152924
return combinedInfo;
29162925
};
29172926

@@ -2930,21 +2939,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
29302939
if (!info.DevicePtrInfoMap.empty()) {
29312940
builder.restoreIP(codeGenIP);
29322941
unsigned argIndex = 0;
2933-
for (auto &devPtrOp : useDevicePtrVars) {
2934-
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
2935-
const auto &arg = region.front().getArgument(argIndex);
2936-
moduleTranslation.mapValue(arg,
2937-
info.DevicePtrInfoMap[mapOpValue].second);
2938-
argIndex++;
2939-
}
2940-
2941-
for (auto &devAddrOp : useDeviceAddrVars) {
2942-
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
2943-
const auto &arg = region.front().getArgument(argIndex);
2944-
auto *LI = builder.CreateLoad(
2945-
builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
2946-
moduleTranslation.mapValue(arg, LI);
2947-
argIndex++;
2942+
for (size_t i = 0; i < combinedInfo.BasePointers.size(); ++i) {
2943+
if (combinedInfo.DevicePointers[i] ==
2944+
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
2945+
const auto &arg = region.front().getArgument(argIndex);
2946+
moduleTranslation.mapValue(
2947+
arg,
2948+
info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
2949+
argIndex++;
2950+
} else if (combinedInfo.DevicePointers[i] ==
2951+
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2952+
const auto &arg = region.front().getArgument(argIndex);
2953+
auto *loadInst = builder.CreateLoad(
2954+
builder.getPtrTy(),
2955+
info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
2956+
moduleTranslation.mapValue(arg, loadInst);
2957+
argIndex++;
2958+
}
29482959
}
29492960

29502961
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
@@ -2957,6 +2968,21 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
29572968
// If device info is available then region has already been generated
29582969
if (info.DevicePtrInfoMap.empty()) {
29592970
builder.restoreIP(codeGenIP);
2971+
// For device pass, if use_device_ptr(addr) mappings were present,
2972+
// we need to link them here before codegen.
2973+
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
2974+
unsigned argIndex = 0;
2975+
for (size_t i = 0; i < mapData.BasePointers.size(); ++i) {
2976+
if (mapData.DevicePointers[i] ==
2977+
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2978+
mapData.DevicePointers[i] ==
2979+
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2980+
const auto &arg = region.front().getArgument(argIndex);
2981+
moduleTranslation.mapValue(arg, mapData.BasePointers[i]);
2982+
argIndex++;
2983+
}
2984+
}
2985+
}
29602986
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
29612987
builder, moduleTranslation);
29622988
}
@@ -3292,14 +3318,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
32923318
findAllocaInsertPoint(builder, moduleTranslation);
32933319

32943320
MapInfoData mapData;
3295-
collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, dl, builder);
3321+
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
3322+
builder);
32963323

32973324
llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
32983325
auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
32993326
-> llvm::OpenMPIRBuilder::MapInfosTy & {
33003327
builder.restoreIP(codeGenIP);
3301-
genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
3302-
true);
3328+
genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
33033329
return combinedInfos;
33043330
};
33053331

0 commit comments

Comments
 (0)