@@ -2101,6 +2101,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
2101
2101
return nullptr ;
2102
2102
}
2103
2103
2104
+ namespace {
2104
2105
// A small helper structure to contain data gathered
2105
2106
// for map lowering and coalese it into one area and
2106
2107
// avoiding extra computations such as searches in the
@@ -2129,6 +2130,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
2129
2130
llvm::OpenMPIRBuilder::MapInfosTy::append (CurInfo);
2130
2131
}
2131
2132
};
2133
+ } // namespace
2132
2134
2133
2135
uint64_t getArrayElementSizeInBits (LLVM::LLVMArrayType arrTy, DataLayout &dl) {
2134
2136
if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
@@ -2195,16 +2197,15 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2195
2197
return builder.getInt64 (dl.getTypeSizeInBits (type) / 8 );
2196
2198
}
2197
2199
2198
- void collectMapDataFromMapOperands (
2199
- MapInfoData &mapData, llvm:: SmallVectorImpl<Value> &mapVars,
2200
+ static void collectMapDataFromMapOperands (
2201
+ MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
2200
2202
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
2201
- llvm::IRBuilderBase &builder,
2202
- const llvm::ArrayRef<Value> &useDevPtrOperands = {},
2203
- const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
2203
+ llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
2204
+ const ArrayRef<Value> &useDevAddrOperands = {}) {
2204
2205
// Process MapOperands
2205
- for (mlir:: Value mapValue : mapVars) {
2206
- auto mapOp = mlir:: cast<mlir:: omp::MapInfoOp>(mapValue.getDefiningOp ());
2207
- mlir:: Value offloadPtr =
2206
+ for (Value mapValue : mapVars) {
2207
+ auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp ());
2208
+ Value offloadPtr =
2208
2209
mapOp.getVarPtrPtr () ? mapOp.getVarPtrPtr () : mapOp.getVarPtr ();
2209
2210
mapData.OriginalValue .push_back (moduleTranslation.lookupValue (offloadPtr));
2210
2211
mapData.Pointers .push_back (mapData.OriginalValue .back ());
@@ -2240,9 +2241,9 @@ void collectMapDataFromMapOperands(
2240
2241
// TODO: May require some further additions to support nested record
2241
2242
// types, i.e. member maps that can have member maps.
2242
2243
mapData.IsAMember .push_back (false );
2243
- for (mlir:: Value mapValue : mapVars) {
2244
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2245
- mapValue.getDefiningOp ())) {
2244
+ for (Value mapValue : mapVars) {
2245
+ if (auto map =
2246
+ dyn_cast_if_present<omp::MapInfoOp>( mapValue.getDefiningOp ())) {
2246
2247
for (auto member : map.getMembers ()) {
2247
2248
if (member == mapOp) {
2248
2249
mapData.IsAMember .back () = true ;
@@ -2271,9 +2272,9 @@ void collectMapDataFromMapOperands(
2271
2272
// Process useDevPtr(Addr)Operands
2272
2273
auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
2273
2274
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2274
- for (mlir:: Value mapValue : useDevOperands) {
2275
- auto mapOp = mlir:: cast<mlir:: omp::MapInfoOp>(mapValue.getDefiningOp ());
2276
- mlir:: Value offloadPtr =
2275
+ for (Value mapValue : useDevOperands) {
2276
+ auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp ());
2277
+ Value offloadPtr =
2277
2278
mapOp.getVarPtrPtr () ? mapOp.getVarPtrPtr () : mapOp.getVarPtr ();
2278
2279
llvm::Value *origValue = moduleTranslation.lookupValue (offloadPtr);
2279
2280
@@ -2302,9 +2303,9 @@ void collectMapDataFromMapOperands(
2302
2303
// TODO: May require some further additions to support nested record
2303
2304
// types, i.e. member maps that can have member maps.
2304
2305
mapData.IsAMember .push_back (false );
2305
- for (mlir:: Value mapValue : useDevOperands)
2306
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2307
- mapValue.getDefiningOp ()))
2306
+ for (Value mapValue : useDevOperands)
2307
+ if (auto map =
2308
+ dyn_cast_if_present<omp::MapInfoOp>( mapValue.getDefiningOp ()))
2308
2309
for (auto member : map.getMembers ())
2309
2310
if (member == mapOp)
2310
2311
mapData.IsAMember .back () = true ;
@@ -2316,22 +2317,21 @@ void collectMapDataFromMapOperands(
2316
2317
addDevInfos (useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
2317
2318
}
2318
2319
2319
- static int getMapDataMemberIdx (MapInfoData &mapData,
2320
- mlir::omp::MapInfoOp memberOp) {
2320
+ static int getMapDataMemberIdx (MapInfoData &mapData, omp::MapInfoOp memberOp) {
2321
2321
auto *res = llvm::find (mapData.MapClause , memberOp);
2322
2322
assert (res != mapData.MapClause .end () &&
2323
2323
" MapInfoOp for member not found in MapData, cannot return index" );
2324
2324
return std::distance (mapData.MapClause .begin (), res);
2325
2325
}
2326
2326
2327
- static mlir:: omp::MapInfoOp
2328
- getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
2329
- mlir:: DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr ();
2327
+ static omp::MapInfoOp getFirstOrLastMappedMemberPtr ( omp::MapInfoOp mapInfo,
2328
+ bool first) {
2329
+ DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr ();
2330
2330
2331
2331
// Only 1 member has been mapped, we can return it.
2332
2332
if (indexAttr.size () == 1 )
2333
- if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2334
- mapInfo.getMembers ()[0 ].getDefiningOp ()))
2333
+ if (auto mapOp =
2334
+ dyn_cast<omp::MapInfoOp>( mapInfo.getMembers ()[0 ].getDefiningOp ()))
2335
2335
return mapOp;
2336
2336
2337
2337
llvm::ArrayRef<int64_t > shape = indexAttr.getShapedType ().getShape ();
@@ -2368,7 +2368,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
2368
2368
return false ;
2369
2369
});
2370
2370
2371
- return llvm::cast<mlir:: omp::MapInfoOp>(
2371
+ return llvm::cast<omp::MapInfoOp>(
2372
2372
mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
2373
2373
}
2374
2374
@@ -2394,7 +2394,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
2394
2394
std::vector<llvm::Value *>
2395
2395
calculateBoundsOffset (LLVM::ModuleTranslation &moduleTranslation,
2396
2396
llvm::IRBuilderBase &builder, bool isArrayTy,
2397
- mlir:: OperandRange bounds) {
2397
+ OperandRange bounds) {
2398
2398
std::vector<llvm::Value *> idx;
2399
2399
// There's no bounds to calculate an offset from, we can safely
2400
2400
// ignore and return no indices.
@@ -2408,7 +2408,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
2408
2408
if (isArrayTy) {
2409
2409
idx.push_back (builder.getInt64 (0 ));
2410
2410
for (int i = bounds.size () - 1 ; i >= 0 ; --i) {
2411
- if (auto boundOp = mlir:: dyn_cast_if_present<mlir:: omp::MapBoundsOp>(
2411
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
2412
2412
bounds[i].getDefiningOp ())) {
2413
2413
idx.push_back (moduleTranslation.lookupValue (boundOp.getLowerBound ()));
2414
2414
}
@@ -2434,7 +2434,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
2434
2434
// (extent/size of current) 100 for 1000 for each index increment
2435
2435
std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64 (1 )};
2436
2436
for (size_t i = 1 ; i < bounds.size (); ++i) {
2437
- if (auto boundOp = mlir:: dyn_cast_if_present<mlir:: omp::MapBoundsOp>(
2437
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
2438
2438
bounds[i].getDefiningOp ())) {
2439
2439
dimensionIndexSizeOffset.push_back (builder.CreateMul (
2440
2440
moduleTranslation.lookupValue (boundOp.getExtent ()),
@@ -2447,7 +2447,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
2447
2447
// have calculated in the previous and accumulate the results to get
2448
2448
// our final resulting offset.
2449
2449
for (int i = bounds.size () - 1 ; i >= 0 ; --i) {
2450
- if (auto boundOp = mlir:: dyn_cast_if_present<mlir:: omp::MapBoundsOp>(
2450
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
2451
2451
bounds[i].getDefiningOp ())) {
2452
2452
if (idx.empty ())
2453
2453
idx.emplace_back (builder.CreateMul (
@@ -2504,7 +2504,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2504
2504
// data by the descriptor (which itself, is a structure containing
2505
2505
// runtime information on the dynamically allocated data).
2506
2506
auto parentClause =
2507
- llvm::cast<mlir:: omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2507
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2508
2508
2509
2509
llvm::Value *lowAddr, *highAddr;
2510
2510
if (!parentClause.getPartialMap ()) {
@@ -2516,8 +2516,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2516
2516
builder.getPtrTy ());
2517
2517
combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
2518
2518
} else {
2519
- auto mapOp =
2520
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2519
+ auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2521
2520
int firstMemberIdx = getMapDataMemberIdx (
2522
2521
mapData, getFirstOrLastMappedMemberPtr (mapOp, true ));
2523
2522
lowAddr = builder.CreatePointerCast (mapData.Pointers [firstMemberIdx],
@@ -2575,7 +2574,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2575
2574
// There may be a better way to verify this, but unfortunately with
2576
2575
// opaque pointers we lose the ability to easily check if something is
2577
2576
// a pointer whilst maintaining access to the underlying type.
2578
- static bool checkIfPointerMap (mlir:: omp::MapInfoOp mapOp) {
2577
+ static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
2579
2578
// If we have a varPtrPtr field assigned then the underlying type is a pointer
2580
2579
if (mapOp.getVarPtrPtr ())
2581
2580
return true ;
@@ -2597,11 +2596,11 @@ static void processMapMembersWithParent(
2597
2596
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
2598
2597
2599
2598
auto parentClause =
2600
- llvm::cast<mlir:: omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2599
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2601
2600
2602
2601
for (auto mappedMembers : parentClause.getMembers ()) {
2603
2602
auto memberClause =
2604
- llvm::cast<mlir:: omp::MapInfoOp>(mappedMembers.getDefiningOp ());
2603
+ llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp ());
2605
2604
int memberDataIdx = getMapDataMemberIdx (mapData, memberClause);
2606
2605
2607
2606
assert (memberDataIdx >= 0 && " could not find mapped member of structure" );
@@ -2635,8 +2634,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
2635
2634
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
2636
2635
// marked with OMP_MAP_PTR_AND_OBJ instead.
2637
2636
auto mapFlag = mapData.Types [mapDataIdx];
2638
- auto mapInfoOp =
2639
- llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause [mapDataIdx]);
2637
+ auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIdx]);
2640
2638
2641
2639
bool isPtrTy = checkIfPointerMap (mapInfoOp);
2642
2640
if (isPtrTy)
@@ -2646,7 +2644,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
2646
2644
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2647
2645
2648
2646
if (mapInfoOp.getMapCaptureType ().value () ==
2649
- mlir:: omp::VariableCaptureKind::ByCopy &&
2647
+ omp::VariableCaptureKind::ByCopy &&
2650
2648
!isPtrTy)
2651
2649
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2652
2650
@@ -2672,13 +2670,13 @@ static void processMapWithMembersOf(
2672
2670
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
2673
2671
uint64_t mapDataIndex, bool isTargetParams) {
2674
2672
auto parentClause =
2675
- llvm::cast<mlir:: omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2673
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
2676
2674
2677
2675
// If we have a partial map (no parent referenced in the map clauses of the
2678
2676
// directive, only members) and only a single member, we do not need to bind
2679
2677
// the map of the member to the parent, we can pass the member separately.
2680
2678
if (parentClause.getMembers ().size () == 1 && parentClause.getPartialMap ()) {
2681
- auto memberClause = llvm::cast<mlir:: omp::MapInfoOp>(
2679
+ auto memberClause = llvm::cast<omp::MapInfoOp>(
2682
2680
parentClause.getMembers ()[0 ].getDefiningOp ());
2683
2681
int memberDataIdx = getMapDataMemberIdx (mapData, memberClause);
2684
2682
// Note: Clang treats arrays with explicit bounds that fall into this
@@ -2715,11 +2713,9 @@ createAlteredByCaptureMap(MapInfoData &mapData,
2715
2713
for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
2716
2714
// if it's declare target, skip it, it's handled separately.
2717
2715
if (!mapData.IsDeclareTarget [i]) {
2718
- auto mapOp =
2719
- mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause [i]);
2720
- mlir::omp::VariableCaptureKind captureKind =
2721
- mapOp.getMapCaptureType ().value_or (
2722
- mlir::omp::VariableCaptureKind::ByRef);
2716
+ auto mapOp = dyn_cast_if_present<omp::MapInfoOp>(mapData.MapClause [i]);
2717
+ omp::VariableCaptureKind captureKind =
2718
+ mapOp.getMapCaptureType ().value_or (omp::VariableCaptureKind::ByRef);
2723
2719
bool isPtrTy = checkIfPointerMap (mapOp);
2724
2720
2725
2721
// Currently handles array sectioning lowerbound case, but more
@@ -2730,7 +2726,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
2730
2726
// function mimics some of the logic from Clang that we require for
2731
2727
// kernel argument passing from host -> device.
2732
2728
switch (captureKind) {
2733
- case mlir:: omp::VariableCaptureKind::ByRef: {
2729
+ case omp::VariableCaptureKind::ByRef: {
2734
2730
llvm::Value *newV = mapData.Pointers [i];
2735
2731
std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset (
2736
2732
moduleTranslation, builder, mapData.BaseType [i]->isArrayTy (),
@@ -2743,7 +2739,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
2743
2739
" array_offset" );
2744
2740
mapData.Pointers [i] = newV;
2745
2741
} break ;
2746
- case mlir:: omp::VariableCaptureKind::ByCopy: {
2742
+ case omp::VariableCaptureKind::ByCopy: {
2747
2743
llvm::Type *type = mapData.BaseType [i];
2748
2744
llvm::Value *newV;
2749
2745
if (mapData.Pointers [i]->getType ()->isPointerTy ())
@@ -2765,8 +2761,8 @@ createAlteredByCaptureMap(MapInfoData &mapData,
2765
2761
mapData.Pointers [i] = newV;
2766
2762
mapData.BasePointers [i] = newV;
2767
2763
} break ;
2768
- case mlir:: omp::VariableCaptureKind::This:
2769
- case mlir:: omp::VariableCaptureKind::VLAType:
2764
+ case omp::VariableCaptureKind::This:
2765
+ case omp::VariableCaptureKind::VLAType:
2770
2766
mapData.MapClause [i]->emitOpError (" Unhandled capture kind" );
2771
2767
break ;
2772
2768
}
@@ -2807,7 +2803,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2807
2803
if (mapData.IsAMember [i])
2808
2804
continue ;
2809
2805
2810
- auto mapInfoOp = mlir:: dyn_cast<mlir:: omp::MapInfoOp>(mapData.MapClause [i]);
2806
+ auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause [i]);
2811
2807
if (!mapInfoOp.getMembers ().empty ()) {
2812
2808
processMapWithMembersOf (moduleTranslation, builder, *ompBuilder, dl,
2813
2809
combinedInfo, mapData, i, isTargetParams);
0 commit comments