@@ -2110,6 +2110,8 @@ getRefPtrIfDeclareTarget(mlir::Value value,
2110
2110
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
2111
2111
llvm::SmallVector<bool , 4 > IsDeclareTarget;
2112
2112
llvm::SmallVector<bool , 4 > IsAMember;
2113
+ // Identify if mapping was added by mapClause or use_device clauses.
2114
+ llvm::SmallVector<bool , 4 > IsAMapping;
2113
2115
llvm::SmallVector<mlir::Operation *, 4 > MapClause;
2114
2116
llvm::SmallVector<llvm::Value *, 4 > OriginalValue;
2115
2117
// Stripped off array/pointer to get the underlying
@@ -2193,62 +2195,125 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2193
2195
return builder.getInt64 (dl.getTypeSizeInBits (type) / 8 );
2194
2196
}
2195
2197
2196
- void collectMapDataFromMapVars (MapInfoData &mapData,
2197
- llvm::SmallVectorImpl<Value> &mapVars,
2198
- LLVM::ModuleTranslation &moduleTranslation,
2199
- DataLayout &dl, llvm::IRBuilderBase &builder) {
2198
+ void collectMapDataFromMapOperands (
2199
+ MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
2200
+ LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
2201
+ llvm::IRBuilderBase &builder,
2202
+ const llvm::ArrayRef<Value> &useDevPtrOperands = {},
2203
+ const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
2204
+ // Process MapOperands
2200
2205
for (mlir::Value mapValue : mapVars) {
2201
- if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2202
- mapValue.getDefiningOp ())) {
2203
- mlir::Value offloadPtr =
2204
- mapOp.getVarPtrPtr () ? mapOp.getVarPtrPtr () : mapOp.getVarPtr ();
2205
- mapData.OriginalValue .push_back (
2206
- moduleTranslation.lookupValue (offloadPtr));
2207
- mapData.Pointers .push_back (mapData.OriginalValue .back ());
2208
-
2209
- if (llvm::Value *refPtr =
2210
- getRefPtrIfDeclareTarget (offloadPtr,
2211
- moduleTranslation)) { // declare target
2212
- mapData.IsDeclareTarget .push_back (true );
2213
- mapData.BasePointers .push_back (refPtr);
2214
- } else { // regular mapped variable
2215
- mapData.IsDeclareTarget .push_back (false );
2216
- mapData.BasePointers .push_back (mapData.OriginalValue .back ());
2217
- }
2206
+ auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp ());
2207
+ mlir::Value offloadPtr =
2208
+ mapOp.getVarPtrPtr () ? mapOp.getVarPtrPtr () : mapOp.getVarPtr ();
2209
+ mapData.OriginalValue .push_back (moduleTranslation.lookupValue (offloadPtr));
2210
+ mapData.Pointers .push_back (mapData.OriginalValue .back ());
2211
+
2212
+ if (llvm::Value *refPtr =
2213
+ getRefPtrIfDeclareTarget (offloadPtr,
2214
+ moduleTranslation)) { // declare target
2215
+ mapData.IsDeclareTarget .push_back (true );
2216
+ mapData.BasePointers .push_back (refPtr);
2217
+ } else { // regular mapped variable
2218
+ mapData.IsDeclareTarget .push_back (false );
2219
+ mapData.BasePointers .push_back (mapData.OriginalValue .back ());
2220
+ }
2218
2221
2219
- mapData.BaseType .push_back (
2220
- moduleTranslation.convertType (mapOp.getVarType ()));
2221
- mapData.Sizes .push_back (
2222
- getSizeInBytes (dl, mapOp.getVarType (), mapOp, mapData.Pointers .back (),
2223
- mapData.BaseType .back (), builder, moduleTranslation));
2224
- mapData.MapClause .push_back (mapOp.getOperation ());
2225
- mapData.Types .push_back (
2226
- llvm::omp::OpenMPOffloadMappingFlags (mapOp.getMapType ().value ()));
2227
- mapData.Names .push_back (LLVM::createMappingInformation (
2228
- mapOp.getLoc (), *moduleTranslation.getOpenMPBuilder ()));
2229
- mapData.DevicePointers .push_back (
2230
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2231
-
2232
- // Check if this is a member mapping and correctly assign that it is, if
2233
- // it is a member of a larger object.
2234
- // TODO: Need better handling of members, and distinguishing of members
2235
- // that are implicitly allocated on device vs explicitly passed in as
2236
- // arguments.
2237
- // TODO: May require some further additions to support nested record
2238
- // types, i.e. member maps that can have member maps.
2239
- mapData.IsAMember .push_back (false );
2240
- for (mlir::Value mapValue : mapVars) {
2241
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2242
- mapValue.getDefiningOp ())) {
2243
- for (auto member : map.getMembers ()) {
2244
- if (member == mapOp) {
2245
- mapData.IsAMember .back () = true ;
2246
- }
2222
+ mapData.BaseType .push_back (
2223
+ moduleTranslation.convertType (mapOp.getVarType ()));
2224
+ mapData.Sizes .push_back (
2225
+ getSizeInBytes (dl, mapOp.getVarType (), mapOp, mapData.Pointers .back (),
2226
+ mapData.BaseType .back (), builder, moduleTranslation));
2227
+ mapData.MapClause .push_back (mapOp.getOperation ());
2228
+ mapData.Types .push_back (
2229
+ llvm::omp::OpenMPOffloadMappingFlags (mapOp.getMapType ().value ()));
2230
+ mapData.Names .push_back (LLVM::createMappingInformation (
2231
+ mapOp.getLoc (), *moduleTranslation.getOpenMPBuilder ()));
2232
+ mapData.DevicePointers .push_back (llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2233
+ mapData.IsAMapping .push_back (true );
2234
+
2235
+ // Check if this is a member mapping and correctly assign that it is, if
2236
+ // it is a member of a larger object.
2237
+ // TODO: Need better handling of members, and distinguishing of members
2238
+ // that are implicitly allocated on device vs explicitly passed in as
2239
+ // arguments.
2240
+ // TODO: May require some further additions to support nested record
2241
+ // types, i.e. member maps that can have member maps.
2242
+ mapData.IsAMember .push_back (false );
2243
+ for (mlir::Value mapValue : mapVars) {
2244
+ if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2245
+ mapValue.getDefiningOp ())) {
2246
+ for (auto member : map.getMembers ()) {
2247
+ if (member == mapOp) {
2248
+ mapData.IsAMember .back () = true ;
2247
2249
}
2248
2250
}
2249
2251
}
2250
2252
}
2251
2253
}
2254
+
2255
+ auto findMapInfo = [&mapData](llvm::Value *val,
2256
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2257
+ unsigned index = 0 ;
2258
+ bool found = false ;
2259
+ for (llvm::Value *basePtr : mapData.OriginalValue ) {
2260
+ if (basePtr == val && mapData.IsAMapping [index]) {
2261
+ found = true ;
2262
+ mapData.Types [index] |=
2263
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2264
+ mapData.DevicePointers [index] = devInfoTy;
2265
+ }
2266
+ index++;
2267
+ }
2268
+ return found;
2269
+ };
2270
+
2271
+ // Process useDevPtr(Addr)Operands
2272
+ auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
2273
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2274
+ for (mlir::Value mapValue : useDevOperands) {
2275
+ auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp ());
2276
+ mlir::Value offloadPtr =
2277
+ mapOp.getVarPtrPtr () ? mapOp.getVarPtrPtr () : mapOp.getVarPtr ();
2278
+ llvm::Value *origValue = moduleTranslation.lookupValue (offloadPtr);
2279
+
2280
+ // Check if map info is already present for this entry.
2281
+ if (!findMapInfo (origValue, devInfoTy)) {
2282
+ mapData.OriginalValue .push_back (origValue);
2283
+ mapData.Pointers .push_back (mapData.OriginalValue .back ());
2284
+ mapData.IsDeclareTarget .push_back (false );
2285
+ mapData.BasePointers .push_back (mapData.OriginalValue .back ());
2286
+ mapData.BaseType .push_back (
2287
+ moduleTranslation.convertType (mapOp.getVarType ()));
2288
+ mapData.Sizes .push_back (builder.getInt64 (0 ));
2289
+ mapData.MapClause .push_back (mapOp.getOperation ());
2290
+ mapData.Types .push_back (
2291
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2292
+ mapData.Names .push_back (LLVM::createMappingInformation (
2293
+ mapOp.getLoc (), *moduleTranslation.getOpenMPBuilder ()));
2294
+ mapData.DevicePointers .push_back (devInfoTy);
2295
+ mapData.IsAMapping .push_back (true );
2296
+
2297
+ // Check if this is a member mapping and correctly assign that it is,
2298
+ // if it is a member of a larger object.
2299
+ // TODO: Need better handling of members, and distinguishing of
2300
+ // members that are implicitly allocated on device vs explicitly
2301
+ // passed in as arguments.
2302
+ // TODO: May require some further additions to support nested record
2303
+ // types, i.e. member maps that can have member maps.
2304
+ mapData.IsAMember .push_back (false );
2305
+ for (mlir::Value mapValue : useDevOperands)
2306
+ if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2307
+ mapValue.getDefiningOp ()))
2308
+ for (auto member : map.getMembers ())
2309
+ if (member == mapOp)
2310
+ mapData.IsAMember .back () = true ;
2311
+ }
2312
+ }
2313
+ };
2314
+
2315
+ addDevInfos (useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2316
+ addDevInfos (useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
2252
2317
}
2253
2318
2254
2319
static int getMapDataMemberIdx (MapInfoData &mapData,
@@ -2426,7 +2491,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2426
2491
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
2427
2492
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
2428
2493
combinedInfo.DevicePointers .emplace_back (
2429
- llvm::OpenMPIRBuilder::DeviceInfoTy::None );
2494
+ mapData. DevicePointers [mapDataIndex] );
2430
2495
combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
2431
2496
mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
2432
2497
combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
@@ -2553,7 +2618,7 @@ static void processMapMembersWithParent(
2553
2618
2554
2619
combinedInfo.Types .emplace_back (mapFlag);
2555
2620
combinedInfo.DevicePointers .emplace_back (
2556
- llvm::OpenMPIRBuilder::DeviceInfoTy::None );
2621
+ mapData. DevicePointers [memberDataIdx] );
2557
2622
combinedInfo.Names .emplace_back (
2558
2623
LLVM::createMappingInformation (memberClause.getLoc (), ompBuilder));
2559
2624
combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
@@ -2714,10 +2779,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2714
2779
LLVM::ModuleTranslation &moduleTranslation,
2715
2780
DataLayout &dl,
2716
2781
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2717
- MapInfoData &mapData,
2718
- const SmallVector<Value> &useDevicePtrVars = {},
2719
- const SmallVector<Value> &useDeviceAddrVars = {},
2720
- bool isTargetParams = false ) {
2782
+ MapInfoData &mapData, bool isTargetParams = false ) {
2721
2783
// We wish to modify some of the methods in which arguments are
2722
2784
// passed based on their capture type by the target region, this can
2723
2785
// involve generating new loads and stores, which changes the
@@ -2734,15 +2796,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2734
2796
2735
2797
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
2736
2798
2737
- auto fail = [&combinedInfo]() -> void {
2738
- combinedInfo.BasePointers .clear ();
2739
- combinedInfo.Pointers .clear ();
2740
- combinedInfo.DevicePointers .clear ();
2741
- combinedInfo.Sizes .clear ();
2742
- combinedInfo.Types .clear ();
2743
- combinedInfo.Names .clear ();
2744
- };
2745
-
2746
2799
// We operate under the assumption that all vectors that are
2747
2800
// required in MapInfoData are of equal lengths (either filled with
2748
2801
// default constructed data or appropiate information) so we can
@@ -2763,46 +2816,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
2763
2816
2764
2817
processIndividualMap (mapData, i, combinedInfo, isTargetParams);
2765
2818
}
2766
-
2767
- auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
2768
- index = 0 ;
2769
- for (llvm::Value *basePtr : combinedInfo.BasePointers ) {
2770
- if (basePtr == val)
2771
- return true ;
2772
- index++;
2773
- }
2774
- return false ;
2775
- };
2776
-
2777
- auto addDevInfos = [&, fail](auto useDeviceVars, auto devOpType) -> void {
2778
- for (const auto &useDeviceVar : useDeviceVars) {
2779
- // TODO: Only LLVMPointerTypes are handled.
2780
- if (!isa<LLVM::LLVMPointerType>(useDeviceVar.getType ()))
2781
- return fail ();
2782
-
2783
- llvm::Value *mapOpValue = moduleTranslation.lookupValue (useDeviceVar);
2784
-
2785
- // Check if map info is already present for this entry.
2786
- unsigned infoIndex;
2787
- if (findMapInfo (mapOpValue, infoIndex)) {
2788
- combinedInfo.Types [infoIndex] |=
2789
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2790
- combinedInfo.DevicePointers [infoIndex] = devOpType;
2791
- } else {
2792
- combinedInfo.BasePointers .emplace_back (mapOpValue);
2793
- combinedInfo.Pointers .emplace_back (mapOpValue);
2794
- combinedInfo.DevicePointers .emplace_back (devOpType);
2795
- combinedInfo.Names .emplace_back (
2796
- LLVM::createMappingInformation (useDeviceVar.getLoc (), *ompBuilder));
2797
- combinedInfo.Types .emplace_back (
2798
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2799
- combinedInfo.Sizes .emplace_back (builder.getInt64 (0 ));
2800
- }
2801
- }
2802
- };
2803
-
2804
- addDevInfos (useDevicePtrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2805
- addDevInfos (useDeviceAddrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
2806
2819
}
2807
2820
2808
2821
static LogicalResult
@@ -2899,19 +2912,15 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2899
2912
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2900
2913
2901
2914
MapInfoData mapData;
2902
- collectMapDataFromMapVars (mapData, mapVars, moduleTranslation, DL, builder);
2915
+ collectMapDataFromMapOperands (mapData, mapVars, moduleTranslation, DL,
2916
+ builder, useDevicePtrVars, useDeviceAddrVars);
2903
2917
2904
2918
// Fill up the arrays with all the mapped variables.
2905
2919
llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
2906
2920
auto genMapInfoCB =
2907
2921
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
2908
2922
builder.restoreIP (codeGenIP);
2909
- if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
2910
- genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapData,
2911
- useDevicePtrVars, useDeviceAddrVars);
2912
- } else {
2913
- genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapData);
2914
- }
2923
+ genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapData);
2915
2924
return combinedInfo;
2916
2925
};
2917
2926
@@ -2930,21 +2939,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2930
2939
if (!info.DevicePtrInfoMap .empty ()) {
2931
2940
builder.restoreIP (codeGenIP);
2932
2941
unsigned argIndex = 0 ;
2933
- for (auto &devPtrOp : useDevicePtrVars) {
2934
- llvm::Value *mapOpValue = moduleTranslation.lookupValue (devPtrOp);
2935
- const auto &arg = region.front ().getArgument (argIndex);
2936
- moduleTranslation.mapValue (arg,
2937
- info.DevicePtrInfoMap [mapOpValue].second );
2938
- argIndex++;
2939
- }
2940
-
2941
- for (auto &devAddrOp : useDeviceAddrVars) {
2942
- llvm::Value *mapOpValue = moduleTranslation.lookupValue (devAddrOp);
2943
- const auto &arg = region.front ().getArgument (argIndex);
2944
- auto *LI = builder.CreateLoad (
2945
- builder.getPtrTy (), info.DevicePtrInfoMap [mapOpValue].second );
2946
- moduleTranslation.mapValue (arg, LI);
2947
- argIndex++;
2942
+ for (size_t i = 0 ; i < combinedInfo.BasePointers .size (); ++i) {
2943
+ if (combinedInfo.DevicePointers [i] ==
2944
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
2945
+ const auto &arg = region.front ().getArgument (argIndex);
2946
+ moduleTranslation.mapValue (
2947
+ arg,
2948
+ info.DevicePtrInfoMap [combinedInfo.BasePointers [i]].second );
2949
+ argIndex++;
2950
+ } else if (combinedInfo.DevicePointers [i] ==
2951
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2952
+ const auto &arg = region.front ().getArgument (argIndex);
2953
+ auto *loadInst = builder.CreateLoad (
2954
+ builder.getPtrTy (),
2955
+ info.DevicePtrInfoMap [combinedInfo.BasePointers [i]].second );
2956
+ moduleTranslation.mapValue (arg, loadInst);
2957
+ argIndex++;
2958
+ }
2948
2959
}
2949
2960
2950
2961
bodyGenStatus = inlineConvertOmpRegions (region, " omp.data.region" ,
@@ -2957,6 +2968,21 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2957
2968
// If device info is available then region has already been generated
2958
2969
if (info.DevicePtrInfoMap .empty ()) {
2959
2970
builder.restoreIP (codeGenIP);
2971
+ // For device pass, if use_device_ptr(addr) mappings were present,
2972
+ // we need to link them here before codegen.
2973
+ if (ompBuilder->Config .IsTargetDevice .value_or (false )) {
2974
+ unsigned argIndex = 0 ;
2975
+ for (size_t i = 0 ; i < mapData.BasePointers .size (); ++i) {
2976
+ if (mapData.DevicePointers [i] ==
2977
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2978
+ mapData.DevicePointers [i] ==
2979
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2980
+ const auto &arg = region.front ().getArgument (argIndex);
2981
+ moduleTranslation.mapValue (arg, mapData.BasePointers [i]);
2982
+ argIndex++;
2983
+ }
2984
+ }
2985
+ }
2960
2986
bodyGenStatus = inlineConvertOmpRegions (region, " omp.data.region" ,
2961
2987
builder, moduleTranslation);
2962
2988
}
@@ -3292,14 +3318,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3292
3318
findAllocaInsertPoint (builder, moduleTranslation);
3293
3319
3294
3320
MapInfoData mapData;
3295
- collectMapDataFromMapVars (mapData, mapVars, moduleTranslation, dl, builder);
3321
+ collectMapDataFromMapOperands (mapData, mapVars, moduleTranslation, dl,
3322
+ builder);
3296
3323
3297
3324
llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
3298
3325
auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
3299
3326
-> llvm::OpenMPIRBuilder::MapInfosTy & {
3300
3327
builder.restoreIP (codeGenIP);
3301
- genMapInfos (builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
3302
- true );
3328
+ genMapInfos (builder, moduleTranslation, dl, combinedInfos, mapData, true );
3303
3329
return combinedInfos;
3304
3330
};
3305
3331
0 commit comments