Skip to content

Commit f40b28e

Browse files
committed
[OpenMP][MLIR] Extend record member map support for omp dialect to LLVM-IR
This patch seeks to refactor slightly and extend the current record type map support that was put in place for Fortran's descriptor types to handle explicit member mapping for record types at a single level of depth (the case of explicit mapping of nested record types is currently unsupported). This patch seeks to support this by extending the OpenMPToLLVMIRTranslation phase to more generally support record types, building on the prior groundwork in the Fortran allocatables/pointers patch. It now supports different kinds of record type mapping, in this case full record type mapping and then explicit member mapping, in which there is a special case for certain types when mapped individually to not require any parent map link in the kernel argument structure. Facilitating this required: * The movement of the setting of the map flag type "ptr_and_obj" to the respective frontends, now supporting it as a possible flag that can be read and printed in MLIR form. Some minor changes to declare target map type setting were necessary for this. * The addition of a member index array operand, which tracks the position of the member in the parent, required for calculating the appropriate size to offload to the target, alongside the parent's offload pointer (always the first member currently being mapped). * A partial mapping attribute operand, to indicate if the entire record type is being mapped or just member components, aiding the ability to lower record types in the different manners that are possible. * Refactoring bounds calculation for record types and general arrays to one location (as well as load/store generation prior to assigning to the kernel argument structure); as a side effect, enter/exit/update/data mapping should now be more correct and fully support bounds mapping, whereas previously this would have only worked for target. Pull Request: llvm#82852
1 parent da60f06 commit f40b28e

File tree

5 files changed

+354
-85
lines changed

5 files changed

+354
-85
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 200 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
#include "llvm/Transforms/Utils/ModuleUtils.h"
3434

3535
#include <any>
36+
#include <cstdint>
37+
#include <numeric>
3638
#include <optional>
3739
#include <utility>
3840

@@ -2050,9 +2052,9 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
20502052

20512053
mapData.BaseType.push_back(
20522054
moduleTranslation.convertType(mapOp.getVarType()));
2053-
mapData.Sizes.push_back(getSizeInBytes(
2054-
dl, mapOp.getVarType(), mapOp, mapData.BasePointers.back(),
2055-
mapData.BaseType.back(), builder, moduleTranslation));
2055+
mapData.Sizes.push_back(
2056+
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2057+
mapData.BaseType.back(), builder, moduleTranslation));
20562058
mapData.MapClause.push_back(mapOp.getOperation());
20572059
mapData.Types.push_back(
20582060
llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
@@ -2083,6 +2085,79 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
20832085
}
20842086
}
20852087

2088+
/// Returns the position of `memberOp` within `mapData.MapClause`.
///
/// The operation is expected to already be present in the list; this is
/// enforced with an assertion rather than reported through the return value.
static int getMapDataMemberIdx(MapInfoData &mapData,
                               mlir::omp::MapInfoOp memberOp) {
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  // std::distance yields a ptrdiff_t; make the narrowing to the int return
  // type explicit.
  return static_cast<int>(std::distance(mapData.MapClause.begin(), res));
}
2094+
2095+
/// Returns the map operation of the explicitly mapped member that comes
/// earliest (`first` == true) or latest (`first` == false) when the rows of
/// the parent's members_index attribute are ordered lexicographically.
///
/// The attribute is read as a 2-D table with one row per mapped member. A
/// value of -1 is treated as an invalid index and, being smaller than every
/// valid index, orders before any valid index at the same position.
static mlir::omp::MapInfoOp
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
  // Only 1 member has been mapped, we can return it.
  if (mapInfo.getMembersIndex()->size() == 1)
    if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
            mapInfo.getMembers()[0].getDefiningOp()))
      return mapOp;

  // Query the attribute once instead of on every comparison.
  auto indexAttr = mapInfo.getMembersIndexAttr();
  const size_t numMembers = indexAttr.getShapedType().getShape()[0];
  const size_t rowWidth = indexAttr.getShapedType().getShape()[1];
  auto indexValues = indexAttr.getValues<int32_t>();

  std::vector<size_t> indices(numMembers);
  std::iota(indices.begin(), indices.end(), 0);

  // The comparator must be a strict weak ordering: it has to return false
  // when two rows are equivalent (including when an element is compared with
  // itself), otherwise the behaviour of the sort is undefined.
  llvm::sort(indices.begin(), indices.end(),
             [&](const size_t a, const size_t b) {
               for (size_t i = 0; i < rowWidth; ++i) {
                 const int aIndex = indexValues[a * rowWidth + i];
                 const int bIndex = indexValues[b * rowWidth + i];

                 // The first differing position decides. Since -1 marks an
                 // invalid index and is smaller than every valid one, a row
                 // that runs out of valid indices first orders earlier.
                 if (aIndex != bIndex)
                   return aIndex < bIndex;
               }

               // Every position matched (likely the same member mapped
               // twice): the rows are equivalent, so neither precedes the
               // other.
               return false;
             });

  assert(!indices.empty() &&
         "expected at least one mapped member in members_index");

  if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
          mapInfo.getMembers()[((first) ? indices.front() : indices.back())]
              .getDefiningOp()))
    return mapOp;

  assert(false && "getFirstOrLastMappedMemberPtr could not find approproaite "
                  "map information");
  return {};
}
2160+
20862161
/// This function calculates the array/pointer offset for map data provided
20872162
/// with bounds operations, e.g. when provided something like the following:
20882163
///
@@ -2188,6 +2263,9 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
21882263
// which is utilised in subsequent member mappings (by modifying there map type
21892264
// with it) to indicate that a member is part of this parent and should be
21902265
// treated by the runtime as such. Important to achieve the correct mapping.
2266+
//
2267+
// This function borrows a lot from its Clang parallel function
2268+
// emitCombinedEntry inside of CGOpenMPRuntime.cpp
21912269
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
21922270
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
21932271
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
@@ -2203,7 +2281,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
22032281
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
22042282
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
22052283
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2206-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
22072284

22082285
// Calculate size of the parent object being mapped based on the
22092286
// addresses at runtime, highAddr - lowAddr = size. This of course
@@ -2212,42 +2289,68 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
22122289
// Fortran pointers and allocatables, the mapping of the pointed to
22132290
// data by the descriptor (which itself, is a structure containing
22142291
// runtime information on the dynamically allocated data).
2215-
llvm::Value *lowAddr = builder.CreatePointerCast(
2216-
mapData.Pointers[mapDataIndex], builder.getPtrTy());
2217-
llvm::Value *highAddr = builder.CreatePointerCast(
2218-
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
2219-
mapData.Pointers[mapDataIndex], 1),
2220-
builder.getPtrTy());
2292+
auto parentClause =
2293+
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2294+
2295+
llvm::Value *lowAddr, *highAddr;
2296+
if (!parentClause.getPartialMap()) {
2297+
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
2298+
builder.getPtrTy());
2299+
highAddr = builder.CreatePointerCast(
2300+
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
2301+
mapData.Pointers[mapDataIndex], 1),
2302+
builder.getPtrTy());
2303+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2304+
} else {
2305+
auto mapOp =
2306+
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2307+
int firstMemberIdx = getMapDataMemberIdx(
2308+
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
2309+
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
2310+
builder.getPtrTy());
2311+
int lastMemberIdx = getMapDataMemberIdx(
2312+
mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
2313+
highAddr = builder.CreatePointerCast(
2314+
builder.CreateGEP(mapData.BaseType[lastMemberIdx],
2315+
mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
2316+
builder.getPtrTy());
2317+
combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
2318+
}
2319+
22212320
llvm::Value *size = builder.CreateIntCast(
22222321
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
22232322
builder.getInt64Ty(),
22242323
/*isSigned=*/false);
22252324
combinedInfo.Sizes.push_back(size);
22262325

2227-
// This creates the initial MEMBER_OF mapping that consists of
2228-
// the parent/top level container (same as above effectively, except
2229-
// with a fixed initial compile time size and seperate maptype which
2230-
// indicates the true mape type (tofrom etc.) and that it is a part
2231-
// of a larger mapping and indicating the link between it and it's
2232-
// members that are also explicitly mapped).
2326+
// TODO: This will need expanded to include the whole host of logic for the
2327+
// map flags that Clang currently supports (e.g. it should take the map flag
2328+
// of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some further
2329+
// case specific flag modifications), for the moment it handles what we
2330+
// support as expected.
22332331
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
22342332
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2235-
if (isTargetParams)
2236-
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
22372333

22382334
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
22392335
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
22402336
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
22412337

2242-
combinedInfo.Types.emplace_back(mapFlag);
2243-
combinedInfo.DevicePointers.emplace_back(
2244-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2245-
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2246-
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2247-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2248-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2249-
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2250-
2338+
// This creates the initial MEMBER_OF mapping that consists of
2339+
// the parent/top level container (same as above effectively, except
2340+
// with a fixed initial compile time size and separate map type which
2341+
// indicates the true map type (tofrom etc.). This parent mapping is
2342+
// only relevant if the structure in its totality is being mapped,
2343+
// otherwise the above suffices.
2344+
if (!parentClause.getPartialMap()) {
2345+
combinedInfo.Types.emplace_back(mapFlag);
2346+
combinedInfo.DevicePointers.emplace_back(
2347+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2348+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2349+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2350+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2351+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2352+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2353+
}
22512354
return memberOfFlag;
22522355
}
22532356

@@ -2285,16 +2388,12 @@ static void processMapMembersWithParent(
22852388
for (auto mappedMembers : parentClause.getMembers()) {
22862389
auto memberClause =
22872390
mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2288-
int memberDataIdx = -1;
2289-
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2290-
if (mapData.MapClause[i] == memberClause)
2291-
memberDataIdx = i;
2292-
}
2391+
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
22932392

22942393
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
22952394

22962395
// Same MemberOfFlag to indicate its link with parent and other members
2297-
// of, and we flag that it's part of a pointer and object coupling.
2396+
// of
22982397
auto mapFlag =
22992398
llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
23002399
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
@@ -2308,18 +2407,81 @@ static void processMapMembersWithParent(
23082407
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
23092408
combinedInfo.Names.emplace_back(
23102409
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2311-
2312-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2410+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
23132411
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
23142412
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
23152413
}
23162414
}
23172415

2416+
static void
2417+
processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
2418+
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2419+
bool isTargetParams, int mapDataParentIdx = -1) {
2420+
// Declare Target Mappings are excluded from being marked as
2421+
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
2422+
// marked with OMP_MAP_PTR_AND_OBJ instead.
2423+
auto mapFlag = mapData.Types[mapDataIdx];
2424+
auto mapInfoOp =
2425+
dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
2426+
2427+
bool isPtrTy = checkIfPointerMap(mapInfoOp);
2428+
if (isPtrTy)
2429+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2430+
2431+
if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
2432+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2433+
2434+
if (mapInfoOp.getMapCaptureType().value() ==
2435+
mlir::omp::VariableCaptureKind::ByCopy &&
2436+
!isPtrTy)
2437+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2438+
2439+
// if we're provided a mapDataParentIdx, then the data being mapped is
2440+
// part of a larger object (in a parent <-> member mapping) and in this
2441+
// case our BasePointer should be the parent.
2442+
if (mapDataParentIdx >= 0)
2443+
combinedInfo.BasePointers.emplace_back(
2444+
mapData.BasePointers[mapDataParentIdx]);
2445+
else
2446+
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
2447+
2448+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
2449+
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
2450+
combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
2451+
combinedInfo.Types.emplace_back(mapFlag);
2452+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
2453+
}
2454+
23182455
static void processMapWithMembersOf(
23192456
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
23202457
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
23212458
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
23222459
uint64_t mapDataIndex, bool isTargetParams) {
2460+
auto parentClause =
2461+
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2462+
2463+
// If we have a partial map (no parent referenced in the map clauses of the
2464+
// directive, only members) and only a single member, we do not need to bind
2465+
// the map of the member to the parent, we can pass the member separately.
2466+
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
2467+
auto memberClause = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2468+
parentClause.getMembers()[0].getDefiningOp());
2469+
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
2470+
// Note: Clang treats arrays with explicit bounds that fall into this
2471+
// category as a parent with map case, however, it seems this isn't a
2472+
// requirement, and processing them as an individual map is fine. So,
2473+
// we will handle them as individual maps for the moment, as it's
2474+
// difficult for us to check this as we always require bounds to be
2475+
// specified currently and it's also marginally more optimal (single
2476+
// map rather than two). The difference may come from the fact that
2477+
// Clang maps array without bounds as pointers (which we do not
2478+
// currently do), whereas we treat them as arrays in all cases
2479+
// currently.
2480+
processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
2481+
mapDataIndex);
2482+
return;
2483+
}
2484+
23232485
llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
23242486
mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
23252487
combinedInfo, mapData, mapDataIndex, isTargetParams);
@@ -2438,12 +2600,8 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
24382600
// utilise the size from any component of MapInfoData, if we can't
24392601
// something is missing from the initial MapInfoData construction.
24402602
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2441-
// NOTE/TODO: We currently do not handle member mapping seperately from it's
2442-
// parent or explicit mapping of a parent and member in the same operation,
2443-
// this will need to change in the near future, for now we primarily handle
2444-
// descriptor mapping from fortran, generalised as mapping record types
2445-
// with implicit member maps. This lowering needs further generalisation to
2446-
// fully support fortran derived types, and C/C++ structures and classes.
2603+
// NOTE/TODO: We currently do not support arbitrary depth record
2604+
// type mapping.
24472605
if (mapData.IsAMember[i])
24482606
continue;
24492607

@@ -2454,28 +2612,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
24542612
continue;
24552613
}
24562614

2457-
auto mapFlag = mapData.Types[i];
2458-
bool isPtrTy = checkIfPointerMap(mapInfoOp);
2459-
if (isPtrTy)
2460-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2461-
2462-
// Declare Target Mappings are excluded from being marked as
2463-
// OMP_MAP_TARGET_PARAM as they are not passed as parameters.
2464-
if (isTargetParams && !mapData.IsDeclareTarget[i])
2465-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2466-
2467-
if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
2468-
if (mapInfoOp.getMapCaptureType().value() ==
2469-
mlir::omp::VariableCaptureKind::ByCopy &&
2470-
!isPtrTy)
2471-
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2472-
2473-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
2474-
combinedInfo.Pointers.emplace_back(mapData.Pointers[i]);
2475-
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[i]);
2476-
combinedInfo.Names.emplace_back(mapData.Names[i]);
2477-
combinedInfo.Types.emplace_back(mapFlag);
2478-
combinedInfo.Sizes.emplace_back(mapData.Sizes[i]);
2615+
processIndividualMap(mapData, i, combinedInfo, isTargetParams);
24792616
}
24802617

24812618
auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {

0 commit comments

Comments
 (0)