Skip to content

Commit 462435f

Browse files
committed
[OpenMP][MLIR] Extend record member map support for omp dialect to LLVM-IR
This patch seeks to refactor slightly and extend the current record type map support that was put in place for Fortran's descriptor types to handle explicit member mapping for record types at a single level of depth (the case of explicit mapping of nested record types is currently unsupported). This patch seeks to support this by extending the OpenMPToLLVMIRTranslation phase to more generally support record types, building on the prior groundwork in the Fortran allocatables/pointers patch. It now supports different kinds of record type mapping, in this case full record type mapping and then explicit member mapping in which there is a special case for certain types when mapped individually to not require any parent map link in the kernel argument structure. Facilitating this required: * The movement of the setting of the map flag type "ptr_and_obj" to respective frontends, now supporting it as a possible flag that can be read and printed in mlir form. Some minor changes to declare target map type setting were necessary for this. * The addition of a member index array operand, which tracks the position of the member in the parent, required for calculating the appropriate size to offload to the target, alongside the parent's offload pointer (always the first member currently being mapped). * A partial mapping attribute operand, to indicate if the entire record type is being mapped or just member components, aiding the ability to lower record types in the different manners that are possible. * Refactoring bounds calculation for record types and general arrays to one location (as well as load/store generation prior to assigning to the kernel argument structure); as a side effect, enter/exit/update/data mapping should now be more correct and fully support bounds mapping, whereas previously this would have only worked for target. Pull Request: #82852
1 parent 50df0ff commit 462435f

File tree

5 files changed

+343
-88
lines changed

5 files changed

+343
-88
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 191 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
#include "llvm/Transforms/Utils/ModuleUtils.h"
3434

3535
#include <any>
36+
#include <cstdint>
3637
#include <iterator>
38+
#include <numeric>
3739
#include <optional>
3840
#include <utility>
3941

@@ -2037,7 +2039,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
20372039
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
20382040
bounds.getDefiningOp())) {
20392041
// The below calculation for the size to be mapped calculated from the
2040-
// map_info's bounds is: (elemCount * [UB - LB] + 1), later we
2042+
// map.info's bounds is: (elemCount * [UB - LB] + 1), later we
20412043
// multiply by the underlying element types byte size to get the full
20422044
// size to be offloaded based on the bounds
20432045
elementCount = builder.CreateMul(
@@ -2089,9 +2091,9 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
20892091

20902092
mapData.BaseType.push_back(
20912093
moduleTranslation.convertType(mapOp.getVarType()));
2092-
mapData.Sizes.push_back(getSizeInBytes(
2093-
dl, mapOp.getVarType(), mapOp, mapData.BasePointers.back(),
2094-
mapData.BaseType.back(), builder, moduleTranslation));
2094+
mapData.Sizes.push_back(
2095+
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2096+
mapData.BaseType.back(), builder, moduleTranslation));
20952097
mapData.MapClause.push_back(mapOp.getOperation());
20962098
mapData.Types.push_back(
20972099
llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
@@ -2122,6 +2124,67 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
21222124
}
21232125
}
21242126

2127+
/// Finds \p memberOp inside \c mapData.MapClause and returns how far it sits
/// from the beginning of that container.
///
/// The lookup is only guarded by an assertion, so the caller is expected to
/// pass an operation that has already been collected into \p mapData.
static int getMapDataMemberIdx(MapInfoData &mapData,
                               mlir::omp::MapInfoOp memberOp) {
  auto *const firstClause = mapData.MapClause.begin();
  auto *const pastLastClause = mapData.MapClause.end();

  auto *memberPos = llvm::find(mapData.MapClause, memberOp);
  assert(memberPos != pastLastClause &&
         "MapInfoOp for member not found in MapData, cannot return index");

  return std::distance(firstClause, memberPos);
}
2134+
2135+
/// Selects one of the member maps attached to \p mapInfo based on the values
/// stored in its members-index attribute.
///
/// The attribute is read as a 2-D table: row `r` holds `shape[1]` int32
/// indices for the member at position `r` of `mapInfo.getMembers()`. Rows are
/// compared position by position:
///   * an index of -1 orders before any other index, regardless of \p first
///     (presumably -1 pads index lists of differing depth -- confirm against
///     the frontend that builds the attribute);
///   * otherwise, when \p first is true the smaller index orders first, and
///     when \p first is false the larger index orders first.
/// The member whose row orders first is returned.
///
/// \param mapInfo Parent map operation whose members are inspected.
/// \param first   true to pick the member with the lowest indices, false to
///                pick the one with the highest.
/// \returns The selected member's MapInfoOp.
static mlir::omp::MapInfoOp
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
  mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();

  // Only 1 member has been mapped, we can return it.
  if (indexAttr.size() == 1)
    if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
            mapInfo.getMembers()[0].getDefiningOp()))
      return mapOp;

  llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
  llvm::SmallVector<size_t> indices(shape[0]);
  std::iota(indices.begin(), indices.end(), 0);
  assert(!indices.empty() &&
         "expected at least one member index entry to select from");

  // The comparator must be a strict weak ordering: it may only return true
  // when `a` is strictly ordered before `b`, and must return false when the
  // two rows are equivalent (including when a row is compared with itself).
  llvm::sort(
      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
        auto indexValues = indexAttr.getValues<int32_t>();
        const int64_t indicesPerMember = shape[1];
        for (int64_t i = 0; i < indicesPerMember; ++i) {
          const int aIndex = indexValues[a * indicesPerMember + i];
          const int bIndex = indexValues[b * indicesPerMember + i];

          if (aIndex == bIndex) {
            // Both rows hit -1 at the same position with an identical prefix:
            // treat them as equivalent, neither orders before the other.
            if (aIndex == -1)
              return false;
            // Same index at this position, decide on a later one.
            continue;
          }

          // Exactly one of the two is -1 from here on; that row always
          // orders first, independent of `first`.
          if (bIndex == -1)
            return false;
          if (aIndex == -1)
            return true;

          // Both are real, distinct indices. A is earlier in the record type
          // layout than B when aIndex < bIndex; flip the order for "last".
          return (aIndex < bIndex) == first;
        }

        // Every position matched: the rows are equivalent.
        return false;
      });

  return llvm::cast<mlir::omp::MapInfoOp>(
      mapInfo.getMembers()[indices.front()].getDefiningOp());
}
2187+
21252188
/// This function calculates the array/pointer offset for map data provided
21262189
/// with bounds operations, e.g. when provided something like the following:
21272190
///
@@ -2227,6 +2290,9 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
22272290
// which is utilised in subsequent member mappings (by modifying their map type
22282291
// with it) to indicate that a member is part of this parent and should be
22292292
// treated by the runtime as such. Important to achieve the correct mapping.
2293+
//
2294+
// This function borrows a lot from Clang's emitCombinedEntry function
2295+
// inside of CGOpenMPRuntime.cpp
22302296
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
22312297
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
22322298
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
@@ -2242,7 +2308,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
22422308
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
22432309
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
22442310
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2245-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
22462311

22472312
// Calculate size of the parent object being mapped based on the
22482313
// addresses at runtime, highAddr - lowAddr = size. This of course
@@ -2251,42 +2316,68 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
22512316
// Fortran pointers and allocatables, the mapping of the pointed to
22522317
// data by the descriptor (which itself, is a structure containing
22532318
// runtime information on the dynamically allocated data).
2254-
llvm::Value *lowAddr = builder.CreatePointerCast(
2255-
mapData.Pointers[mapDataIndex], builder.getPtrTy());
2256-
llvm::Value *highAddr = builder.CreatePointerCast(
2257-
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
2258-
mapData.Pointers[mapDataIndex], 1),
2259-
builder.getPtrTy());
2319+
auto parentClause =
2320+
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2321+
2322+
llvm::Value *lowAddr, *highAddr;
2323+
if (!parentClause.getPartialMap()) {
2324+
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
2325+
builder.getPtrTy());
2326+
highAddr = builder.CreatePointerCast(
2327+
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
2328+
mapData.Pointers[mapDataIndex], 1),
2329+
builder.getPtrTy());
2330+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2331+
} else {
2332+
auto mapOp =
2333+
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2334+
int firstMemberIdx = getMapDataMemberIdx(
2335+
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
2336+
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
2337+
builder.getPtrTy());
2338+
int lastMemberIdx = getMapDataMemberIdx(
2339+
mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
2340+
highAddr = builder.CreatePointerCast(
2341+
builder.CreateGEP(mapData.BaseType[lastMemberIdx],
2342+
mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
2343+
builder.getPtrTy());
2344+
combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
2345+
}
2346+
22602347
llvm::Value *size = builder.CreateIntCast(
22612348
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
22622349
builder.getInt64Ty(),
22632350
/*isSigned=*/false);
22642351
combinedInfo.Sizes.push_back(size);
22652352

2266-
// This creates the initial MEMBER_OF mapping that consists of
2267-
// the parent/top level container (same as above effectively, except
2268-
// with a fixed initial compile time size and seperate maptype which
2269-
// indicates the true mape type (tofrom etc.) and that it is a part
2270-
// of a larger mapping and indicating the link between it and it's
2271-
// members that are also explicitly mapped).
2353+
// TODO: This will need to be expanded to include the whole host of logic for
2354+
// the map flags that Clang currently supports (e.g. it should take the map
2355+
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2356+
// further case specific flag modifications). For the moment, it handles what
2357+
// we support as expected.
22722358
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
22732359
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2274-
if (isTargetParams)
2275-
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
22762360

22772361
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
22782362
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
22792363
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
22802364

2281-
combinedInfo.Types.emplace_back(mapFlag);
2282-
combinedInfo.DevicePointers.emplace_back(
2283-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2284-
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2285-
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2286-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2287-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2288-
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2289-
2365+
// This creates the initial MEMBER_OF mapping that consists of
2366+
// the parent/top level container (same as above effectively, except
2367+
// with a fixed initial compile time size and separate map type which
2368+
// indicates the true map type (tofrom etc.). This parent mapping is
2369+
// only relevant if the structure in its totality is being mapped,
2370+
// otherwise the above suffices.
2371+
if (!parentClause.getPartialMap()) {
2372+
combinedInfo.Types.emplace_back(mapFlag);
2373+
combinedInfo.DevicePointers.emplace_back(
2374+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2375+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2376+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2377+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2378+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2379+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2380+
}
22902381
return memberOfFlag;
22912382
}
22922383

@@ -2319,21 +2410,17 @@ static void processMapMembersWithParent(
23192410
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
23202411

23212412
auto parentClause =
2322-
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2413+
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
23232414

23242415
for (auto mappedMembers : parentClause.getMembers()) {
23252416
auto memberClause =
2326-
mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2327-
int memberDataIdx = -1;
2328-
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2329-
if (mapData.MapClause[i] == memberClause)
2330-
memberDataIdx = i;
2331-
}
2417+
llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2418+
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
23322419

23332420
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
23342421

23352422
// Same MemberOfFlag to indicate its link with parent and other members
2336-
// of, and we flag that it's part of a pointer and object coupling.
2423+
// of.
23372424
auto mapFlag =
23382425
llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
23392426
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
@@ -2347,18 +2434,81 @@ static void processMapMembersWithParent(
23472434
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
23482435
combinedInfo.Names.emplace_back(
23492436
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2350-
2351-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2437+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
23522438
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
23532439
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
23542440
}
23552441
}
23562442

2443+
static void
2444+
processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
2445+
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2446+
bool isTargetParams, int mapDataParentIdx = -1) {
2447+
// Declare Target Mappings are excluded from being marked as
2448+
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
2449+
// marked with OMP_MAP_PTR_AND_OBJ instead.
2450+
auto mapFlag = mapData.Types[mapDataIdx];
2451+
auto mapInfoOp =
2452+
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
2453+
2454+
bool isPtrTy = checkIfPointerMap(mapInfoOp);
2455+
if (isPtrTy)
2456+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2457+
2458+
if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
2459+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2460+
2461+
if (mapInfoOp.getMapCaptureType().value() ==
2462+
mlir::omp::VariableCaptureKind::ByCopy &&
2463+
!isPtrTy)
2464+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2465+
2466+
// if we're provided a mapDataParentIdx, then the data being mapped is
2467+
// part of a larger object (in a parent <-> member mapping) and in this
2468+
// case our BasePointer should be the parent.
2469+
if (mapDataParentIdx >= 0)
2470+
combinedInfo.BasePointers.emplace_back(
2471+
mapData.BasePointers[mapDataParentIdx]);
2472+
else
2473+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
2474+
2475+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
2476+
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
2477+
combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
2478+
combinedInfo.Types.emplace_back(mapFlag);
2479+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
2480+
}
2481+
23572482
static void processMapWithMembersOf(
23582483
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
23592484
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
23602485
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
23612486
uint64_t mapDataIndex, bool isTargetParams) {
2487+
auto parentClause =
2488+
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2489+
2490+
// If we have a partial map (no parent referenced in the map clauses of the
2491+
// directive, only members) and only a single member, we do not need to bind
2492+
// the map of the member to the parent, we can pass the member separately.
2493+
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
2494+
auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
2495+
parentClause.getMembers()[0].getDefiningOp());
2496+
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
2497+
// Note: Clang treats arrays with explicit bounds that fall into this
2498+
// category as a parent with map case, however, it seems this isn't a
2499+
// requirement, and processing them as an individual map is fine. So,
2500+
// we will handle them as individual maps for the moment, as it's
2501+
// difficult for us to check this as we always require bounds to be
2502+
// specified currently and it's also marginally more optimal (single
2503+
// map rather than two). The difference may come from the fact that
2504+
// Clang maps array without bounds as pointers (which we do not
2505+
// currently do), whereas we treat them as arrays in all cases
2506+
// currently.
2507+
processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
2508+
mapDataIndex);
2509+
return;
2510+
}
2511+
23622512
llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
23632513
mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
23642514
combinedInfo, mapData, mapDataIndex, isTargetParams);
@@ -2477,12 +2627,8 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
24772627
// utilise the size from any component of MapInfoData, if we can't
24782628
// something is missing from the initial MapInfoData construction.
24792629
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2480-
// NOTE/TODO: We currently do not handle member mapping seperately from it's
2481-
// parent or explicit mapping of a parent and member in the same operation,
2482-
// this will need to change in the near future, for now we primarily handle
2483-
// descriptor mapping from fortran, generalised as mapping record types
2484-
// with implicit member maps. This lowering needs further generalisation to
2485-
// fully support fortran derived types, and C/C++ structures and classes.
2630+
// NOTE/TODO: We currently do not support arbitrary depth record
2631+
// type mapping.
24862632
if (mapData.IsAMember[i])
24872633
continue;
24882634

@@ -2493,28 +2639,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
24932639
continue;
24942640
}
24952641

2496-
auto mapFlag = mapData.Types[i];
2497-
bool isPtrTy = checkIfPointerMap(mapInfoOp);
2498-
if (isPtrTy)
2499-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2500-
2501-
// Declare Target Mappings are excluded from being marked as
2502-
// OMP_MAP_TARGET_PARAM as they are not passed as parameters.
2503-
if (isTargetParams && !mapData.IsDeclareTarget[i])
2504-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2505-
2506-
if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
2507-
if (mapInfoOp.getMapCaptureType().value() ==
2508-
mlir::omp::VariableCaptureKind::ByCopy &&
2509-
!isPtrTy)
2510-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2511-
2512-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
2513-
combinedInfo.Pointers.emplace_back(mapData.Pointers[i]);
2514-
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[i]);
2515-
combinedInfo.Names.emplace_back(mapData.Names[i]);
2516-
combinedInfo.Types.emplace_back(mapFlag);
2517-
combinedInfo.Sizes.emplace_back(mapData.Sizes[i]);
2642+
processIndividualMap(mapData, i, combinedInfo, isTargetParams);
25182643
}
25192644

25202645
auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {

0 commit comments

Comments
 (0)