33
33
#include " llvm/Transforms/Utils/ModuleUtils.h"
34
34
35
35
#include < any>
36
+ #include < cstdint>
36
37
#include < iterator>
38
+ #include < numeric>
37
39
#include < optional>
38
40
#include < utility>
39
41
@@ -2037,7 +2039,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2037
2039
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2038
2040
bounds.getDefiningOp ())) {
2039
2041
// The below calculation for the size to be mapped calculated from the
2040
- // map_info 's bounds is: (elemCount * [UB - LB] + 1), later we
2042
+ // map.info 's bounds is: (elemCount * [UB - LB] + 1), later we
2041
2043
// multiply by the underlying element types byte size to get the full
2042
2044
// size to be offloaded based on the bounds
2043
2045
elementCount = builder.CreateMul (
@@ -2089,9 +2091,9 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
2089
2091
2090
2092
mapData.BaseType .push_back (
2091
2093
moduleTranslation.convertType (mapOp.getVarType ()));
2092
- mapData.Sizes .push_back (getSizeInBytes (
2093
- dl, mapOp.getVarType (), mapOp, mapData.BasePointers .back (),
2094
- mapData.BaseType .back (), builder, moduleTranslation));
2094
+ mapData.Sizes .push_back (
2095
+ getSizeInBytes ( dl, mapOp.getVarType (), mapOp, mapData.Pointers .back (),
2096
+ mapData.BaseType .back (), builder, moduleTranslation));
2095
2097
mapData.MapClause .push_back (mapOp.getOperation ());
2096
2098
mapData.Types .push_back (
2097
2099
llvm::omp::OpenMPOffloadMappingFlags (mapOp.getMapType ().value ()));
@@ -2122,6 +2124,67 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
2122
2124
}
2123
2125
}
2124
2126
2127
+ static int getMapDataMemberIdx (MapInfoData &mapData,
2128
+ mlir::omp::MapInfoOp memberOp) {
2129
+ auto *res = llvm::find (mapData.MapClause , memberOp);
2130
+ assert (res != mapData.MapClause .end () &&
2131
+ " MapInfoOp for member not found in MapData, cannot return index" );
2132
+ return std::distance (mapData.MapClause .begin (), res);
2133
+ }
2134
+
2135
+ static mlir::omp::MapInfoOp
2136
+ getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
2137
+ mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr ();
2138
+
2139
+ // Only 1 member has been mapped, we can return it.
2140
+ if (indexAttr.size () == 1 )
2141
+ if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2142
+ mapInfo.getMembers ()[0 ].getDefiningOp ()))
2143
+ return mapOp;
2144
+
2145
+ llvm::ArrayRef<int64_t > shape = indexAttr.getShapedType ().getShape ();
2146
+ llvm::SmallVector<size_t > indices (shape[0 ]);
2147
+ std::iota (indices.begin (), indices.end (), 0 );
2148
+
2149
+ llvm::sort (
2150
+ indices.begin (), indices.end (), [&](const size_t a, const size_t b) {
2151
+ auto indexValues = indexAttr.getValues <int32_t >();
2152
+ for (int i = 0 ;
2153
+ i < shape[1 ];
2154
+ ++i) {
2155
+ int aIndex = indexValues[a * shape[1 ] + i];
2156
+ int bIndex = indexValues[b * shape[1 ] + i];
2157
+
2158
+ if (aIndex != -1 && bIndex == -1 )
2159
+ return false ;
2160
+
2161
+ if (aIndex == -1 && bIndex != -1 )
2162
+ return true ;
2163
+
2164
+ if (aIndex == -1 )
2165
+ return first;
2166
+
2167
+ if (bIndex == -1 )
2168
+ return !first;
2169
+
2170
+ // A is earlier in the record type layout than B
2171
+ if (aIndex < bIndex)
2172
+ return first;
2173
+
2174
+ if (bIndex < aIndex)
2175
+ return !first;
2176
+ }
2177
+
2178
+ // iterated the entire list and couldn't make a decision, all elements
2179
+ // were likely the same, return true for now similar to reaching the end
2180
+ // of both and finding invalid indices.
2181
+ return true ;
2182
+ });
2183
+
2184
+ return llvm::cast<mlir::omp::MapInfoOp>(
2185
+ mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
2186
+ }
2187
+
2125
2188
// / This function calculates the array/pointer offset for map data provided
2126
2189
// / with bounds operations, e.g. when provided something like the following:
2127
2190
// /
@@ -2227,6 +2290,9 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
2227
2290
// which is utilised in subsequent member mappings (by modifying their map type
2228
2291
// with it) to indicate that a member is part of this parent and should be
2229
2292
// treated by the runtime as such. Important to achieve the correct mapping.
2293
+ //
2294
+ // This function borrows a lot from Clang's emitCombinedEntry function
2295
+ // inside of CGOpenMPRuntime.cpp
2230
2296
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers (
2231
2297
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
2232
2298
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
@@ -2242,7 +2308,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2242
2308
combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
2243
2309
mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
2244
2310
combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
2245
- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
2246
2311
2247
2312
// Calculate size of the parent object being mapped based on the
2248
2313
// addresses at runtime, highAddr - lowAddr = size. This of course
@@ -2251,42 +2316,68 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2251
2316
// Fortran pointers and allocatables, the mapping of the pointed to
2252
2317
// data by the descriptor (which itself, is a structure containing
2253
2318
// runtime information on the dynamically allocated data).
2254
- llvm::Value *lowAddr = builder.CreatePointerCast (
2255
- mapData.Pointers [mapDataIndex], builder.getPtrTy ());
2256
- llvm::Value *highAddr = builder.CreatePointerCast (
2257
- builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
2258
- mapData.Pointers [mapDataIndex], 1 ),
2259
- builder.getPtrTy ());
2319
+ auto parentClause =
2320
+ llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2321
+
2322
+ llvm::Value *lowAddr, *highAddr;
2323
+ if (!parentClause.getPartialMap ()) {
2324
+ lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
2325
+ builder.getPtrTy ());
2326
+ highAddr = builder.CreatePointerCast (
2327
+ builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
2328
+ mapData.Pointers [mapDataIndex], 1 ),
2329
+ builder.getPtrTy ());
2330
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
2331
+ } else {
2332
+ auto mapOp =
2333
+ mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2334
+ int firstMemberIdx = getMapDataMemberIdx (
2335
+ mapData, getFirstOrLastMappedMemberPtr (mapOp, true ));
2336
+ lowAddr = builder.CreatePointerCast (mapData.Pointers [firstMemberIdx],
2337
+ builder.getPtrTy ());
2338
+ int lastMemberIdx = getMapDataMemberIdx (
2339
+ mapData, getFirstOrLastMappedMemberPtr (mapOp, false ));
2340
+ highAddr = builder.CreatePointerCast (
2341
+ builder.CreateGEP (mapData.BaseType [lastMemberIdx],
2342
+ mapData.Pointers [lastMemberIdx], builder.getInt64 (1 )),
2343
+ builder.getPtrTy ());
2344
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [firstMemberIdx]);
2345
+ }
2346
+
2260
2347
llvm::Value *size = builder.CreateIntCast (
2261
2348
builder.CreatePtrDiff (builder.getInt8Ty (), highAddr, lowAddr),
2262
2349
builder.getInt64Ty (),
2263
2350
/* isSigned=*/ false );
2264
2351
combinedInfo.Sizes .push_back (size);
2265
2352
2266
- // This creates the initial MEMBER_OF mapping that consists of
2267
- // the parent/top level container (same as above effectively, except
2268
- // with a fixed initial compile time size and seperate maptype which
2269
- // indicates the true mape type (tofrom etc.) and that it is a part
2270
- // of a larger mapping and indicating the link between it and it's
2271
- // members that are also explicitly mapped).
2353
+ // TODO: This will need to be expanded to include the whole host of logic for
2354
+ // the map flags that Clang currently supports (e.g. it should take the map
2355
+ // flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2356
+ // further case specific flag modifications). For the moment, it handles what
2357
+ // we support as expected.
2272
2358
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2273
2359
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2274
- if (isTargetParams)
2275
- mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2276
2360
2277
2361
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
2278
2362
ompBuilder.getMemberOfFlag (combinedInfo.BasePointers .size () - 1 );
2279
2363
ompBuilder.setCorrectMemberOfFlag (mapFlag, memberOfFlag);
2280
2364
2281
- combinedInfo.Types .emplace_back (mapFlag);
2282
- combinedInfo.DevicePointers .emplace_back (
2283
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2284
- combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
2285
- mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
2286
- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
2287
- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
2288
- combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
2289
-
2365
+ // This creates the initial MEMBER_OF mapping that consists of
2366
+ // the parent/top level container (same as above effectively, except
2367
+ // with a fixed initial compile time size and separate maptype which
2368
+ // indicates the true map type (tofrom etc.). This parent mapping is
2369
+ // only relevant if the structure in its totality is being mapped,
2370
+ // otherwise the above suffices.
2371
+ if (!parentClause.getPartialMap ()) {
2372
+ combinedInfo.Types .emplace_back (mapFlag);
2373
+ combinedInfo.DevicePointers .emplace_back (
2374
+ llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2375
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
2376
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
2377
+ combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
2378
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
2379
+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
2380
+ }
2290
2381
return memberOfFlag;
2291
2382
}
2292
2383
@@ -2319,21 +2410,17 @@ static void processMapMembersWithParent(
2319
2410
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
2320
2411
2321
2412
auto parentClause =
2322
- mlir::dyn_cast <mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2413
+ llvm::cast <mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2323
2414
2324
2415
for (auto mappedMembers : parentClause.getMembers ()) {
2325
2416
auto memberClause =
2326
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp ());
2327
- int memberDataIdx = -1 ;
2328
- for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
2329
- if (mapData.MapClause [i] == memberClause)
2330
- memberDataIdx = i;
2331
- }
2417
+ llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp ());
2418
+ int memberDataIdx = getMapDataMemberIdx (mapData, memberClause);
2332
2419
2333
2420
assert (memberDataIdx >= 0 && " could not find mapped member of structure" );
2334
2421
2335
2422
// Same MemberOfFlag to indicate its link with parent and other members
2336
- // of, and we flag that it's part of a pointer and object coupling .
2423
+ // of.
2337
2424
auto mapFlag =
2338
2425
llvm::omp::OpenMPOffloadMappingFlags (memberClause.getMapType ().value ());
2339
2426
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
@@ -2347,18 +2434,81 @@ static void processMapMembersWithParent(
2347
2434
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2348
2435
combinedInfo.Names .emplace_back (
2349
2436
LLVM::createMappingInformation (memberClause.getLoc (), ompBuilder));
2350
-
2351
- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [memberDataIdx]);
2437
+ combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
2352
2438
combinedInfo.Pointers .emplace_back (mapData.Pointers [memberDataIdx]);
2353
2439
combinedInfo.Sizes .emplace_back (mapData.Sizes [memberDataIdx]);
2354
2440
}
2355
2441
}
2356
2442
2443
+ static void
2444
+ processIndividualMap (MapInfoData &mapData, size_t mapDataIdx,
2445
+ llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2446
+ bool isTargetParams, int mapDataParentIdx = -1 ) {
2447
+ // Declare Target Mappings are excluded from being marked as
2448
+ // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
2449
+ // marked with OMP_MAP_PTR_AND_OBJ instead.
2450
+ auto mapFlag = mapData.Types [mapDataIdx];
2451
+ auto mapInfoOp =
2452
+ llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIdx]);
2453
+
2454
+ bool isPtrTy = checkIfPointerMap (mapInfoOp);
2455
+ if (isPtrTy)
2456
+ mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2457
+
2458
+ if (isTargetParams && !mapData.IsDeclareTarget [mapDataIdx])
2459
+ mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2460
+
2461
+ if (mapInfoOp.getMapCaptureType ().value () ==
2462
+ mlir::omp::VariableCaptureKind::ByCopy &&
2463
+ !isPtrTy)
2464
+ mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2465
+
2466
+ // if we're provided a mapDataParentIdx, then the data being mapped is
2467
+ // part of a larger object (in a parent <-> member mapping) and in this
2468
+ // case our BasePointer should be the parent.
2469
+ if (mapDataParentIdx >= 0 )
2470
+ combinedInfo.BasePointers .emplace_back (
2471
+ mapData.BasePointers [mapDataParentIdx]);
2472
+ else
2473
+ combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIdx]);
2474
+
2475
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIdx]);
2476
+ combinedInfo.DevicePointers .emplace_back (mapData.DevicePointers [mapDataIdx]);
2477
+ combinedInfo.Names .emplace_back (mapData.Names [mapDataIdx]);
2478
+ combinedInfo.Types .emplace_back (mapFlag);
2479
+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIdx]);
2480
+ }
2481
+
2357
2482
static void processMapWithMembersOf (
2358
2483
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
2359
2484
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
2360
2485
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
2361
2486
uint64_t mapDataIndex, bool isTargetParams) {
2487
+ auto parentClause =
2488
+ llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2489
+
2490
+ // If we have a partial map (no parent referenced in the map clauses of the
2491
+ // directive, only members) and only a single member, we do not need to bind
2492
+ // the map of the member to the parent, we can pass the member separately.
2493
+ if (parentClause.getMembers ().size () == 1 && parentClause.getPartialMap ()) {
2494
+ auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
2495
+ parentClause.getMembers ()[0 ].getDefiningOp ());
2496
+ int memberDataIdx = getMapDataMemberIdx (mapData, memberClause);
2497
+ // Note: Clang treats arrays with explicit bounds that fall into this
2498
+ // category as a parent with map case, however, it seems this isn't a
2499
+ // requirement, and processing them as an individual map is fine. So,
2500
+ // we will handle them as individual maps for the moment, as it's
2501
+ // difficult for us to check this as we always require bounds to be
2502
+ // specified currently and it's also marginally more optimal (single
2503
+ // map rather than two). The difference may come from the fact that
2504
+ // Clang maps array without bounds as pointers (which we do not
2505
+ // currently do), whereas we treat them as arrays in all cases
2506
+ // currently.
2507
+ processIndividualMap (mapData, memberDataIdx, combinedInfo, isTargetParams,
2508
+ mapDataIndex);
2509
+ return ;
2510
+ }
2511
+
2362
2512
llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
2363
2513
mapParentWithMembers (moduleTranslation, builder, ompBuilder, dl,
2364
2514
combinedInfo, mapData, mapDataIndex, isTargetParams);
@@ -2477,12 +2627,8 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2477
2627
// utilise the size from any component of MapInfoData, if we can't
2478
2628
// something is missing from the initial MapInfoData construction.
2479
2629
for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
2480
- // NOTE/TODO: We currently do not handle member mapping seperately from it's
2481
- // parent or explicit mapping of a parent and member in the same operation,
2482
- // this will need to change in the near future, for now we primarily handle
2483
- // descriptor mapping from fortran, generalised as mapping record types
2484
- // with implicit member maps. This lowering needs further generalisation to
2485
- // fully support fortran derived types, and C/C++ structures and classes.
2630
+ // NOTE/TODO: We currently do not support arbitrary depth record
2631
+ // type mapping.
2486
2632
if (mapData.IsAMember [i])
2487
2633
continue ;
2488
2634
@@ -2493,28 +2639,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2493
2639
continue ;
2494
2640
}
2495
2641
2496
- auto mapFlag = mapData.Types [i];
2497
- bool isPtrTy = checkIfPointerMap (mapInfoOp);
2498
- if (isPtrTy)
2499
- mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2500
-
2501
- // Declare Target Mappings are excluded from being marked as
2502
- // OMP_MAP_TARGET_PARAM as they are not passed as parameters.
2503
- if (isTargetParams && !mapData.IsDeclareTarget [i])
2504
- mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2505
-
2506
- if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause [i]))
2507
- if (mapInfoOp.getMapCaptureType ().value () ==
2508
- mlir::omp::VariableCaptureKind::ByCopy &&
2509
- !isPtrTy)
2510
- mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2511
-
2512
- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [i]);
2513
- combinedInfo.Pointers .emplace_back (mapData.Pointers [i]);
2514
- combinedInfo.DevicePointers .emplace_back (mapData.DevicePointers [i]);
2515
- combinedInfo.Names .emplace_back (mapData.Names [i]);
2516
- combinedInfo.Types .emplace_back (mapFlag);
2517
- combinedInfo.Sizes .emplace_back (mapData.Sizes [i]);
2642
+ processIndividualMap (mapData, i, combinedInfo, isTargetParams);
2518
2643
}
2519
2644
2520
2645
auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
0 commit comments