@@ -2242,14 +2242,10 @@ static void collectMapDataFromMapOperands(
2242
2242
// types, i.e. member maps that can have member maps.
2243
2243
mapData.IsAMember .push_back (false );
2244
2244
for (Value mapValue : mapVars) {
2245
- if (auto map =
2246
- dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp ())) {
2247
- for (auto member : map.getMembers ()) {
2248
- if (member == mapOp) {
2249
- mapData.IsAMember .back () = true ;
2250
- }
2251
- }
2252
- }
2245
+ auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp ());
2246
+ for (auto member : map.getMembers ())
2247
+ if (member == mapOp)
2248
+ mapData.IsAMember .back () = true ;
2253
2249
}
2254
2250
}
2255
2251
@@ -2303,12 +2299,12 @@ static void collectMapDataFromMapOperands(
2303
2299
// TODO: May require some further additions to support nested record
2304
2300
// types, i.e. member maps that can have member maps.
2305
2301
mapData.IsAMember .push_back (false );
2306
- for (Value mapValue : useDevOperands)
2307
- if ( auto map =
2308
- dyn_cast_if_present<omp::MapInfoOp>(mapValue. getDefiningOp () ))
2309
- for ( auto member : map. getMembers () )
2310
- if (member == mapOp)
2311
- mapData. IsAMember . back () = true ;
2302
+ for (Value mapValue : useDevOperands) {
2303
+ auto map = cast<omp::MapInfoOp>(mapValue. getDefiningOp ());
2304
+ for ( auto member : map. getMembers ( ))
2305
+ if ( member == mapOp )
2306
+ mapData. IsAMember . back () = true ;
2307
+ }
2312
2308
}
2313
2309
}
2314
2310
};
@@ -2713,7 +2709,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
2713
2709
for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
2714
2710
// if it's declare target, skip it, it's handled separately.
2715
2711
if (!mapData.IsDeclareTarget [i]) {
2716
- auto mapOp = dyn_cast_if_present <omp::MapInfoOp>(mapData.MapClause [i]);
2712
+ auto mapOp = cast <omp::MapInfoOp>(mapData.MapClause [i]);
2717
2713
omp::VariableCaptureKind captureKind =
2718
2714
mapOp.getMapCaptureType ().value_or (omp::VariableCaptureKind::ByRef);
2719
2715
bool isPtrTy = checkIfPointerMap (mapOp);
@@ -2935,20 +2931,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2935
2931
if (!info.DevicePtrInfoMap .empty ()) {
2936
2932
builder.restoreIP (codeGenIP);
2937
2933
unsigned argIndex = 0 ;
2938
- for (size_t i = 0 ; i < combinedInfo. BasePointers . size (); ++i) {
2939
- if ( combinedInfo.DevicePointers [i] ==
2940
- llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
2934
+ for (auto [basePointer, devicePointer] : llvm::zip_equal (
2935
+ combinedInfo.BasePointers , combinedInfo. DevicePointers )) {
2936
+ if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
2941
2937
const auto &arg = region.front ().getArgument (argIndex);
2942
2938
moduleTranslation.mapValue (
2943
- arg,
2944
- info.DevicePtrInfoMap [combinedInfo.BasePointers [i]].second );
2939
+ arg, info.DevicePtrInfoMap [basePointer].second );
2945
2940
argIndex++;
2946
- } else if (combinedInfo. DevicePointers [i] ==
2941
+ } else if (devicePointer ==
2947
2942
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2948
2943
const auto &arg = region.front ().getArgument (argIndex);
2949
2944
auto *loadInst = builder.CreateLoad (
2950
- builder.getPtrTy (),
2951
- info.DevicePtrInfoMap [combinedInfo.BasePointers [i]].second );
2945
+ builder.getPtrTy (), info.DevicePtrInfoMap [basePointer].second );
2952
2946
moduleTranslation.mapValue (arg, loadInst);
2953
2947
argIndex++;
2954
2948
}
@@ -2968,13 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2968
2962
// we need to link them here before codegen.
2969
2963
if (ompBuilder->Config .IsTargetDevice .value_or (false )) {
2970
2964
unsigned argIndex = 0 ;
2971
- for (size_t i = 0 ; i < mapData.BasePointers .size (); ++i) {
2972
- if (mapData.DevicePointers [i] ==
2973
- llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2974
- mapData.DevicePointers [i] ==
2975
- llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2965
+ for (auto [basePointer, devicePointer] :
2966
+ llvm::zip_equal (mapData.BasePointers , mapData.DevicePointers )) {
2967
+ if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2968
+ devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2976
2969
const auto &arg = region.front ().getArgument (argIndex);
2977
- moduleTranslation.mapValue (arg, mapData. BasePointers [i] );
2970
+ moduleTranslation.mapValue (arg, basePointer );
2978
2971
argIndex++;
2979
2972
}
2980
2973
}
@@ -3198,17 +3191,14 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3198
3191
llvm::IRBuilderBase::InsertPoint codeGenIP) {
3199
3192
builder.restoreIP (allocaIP);
3200
3193
3201
- mlir::omp::VariableCaptureKind capture =
3202
- mlir::omp::VariableCaptureKind::ByRef;
3194
+ omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
3203
3195
3204
3196
// Find the associated MapInfoData entry for the current input
3205
3197
for (size_t i = 0 ; i < mapData.MapClause .size (); ++i)
3206
3198
if (mapData.OriginalValue [i] == input) {
3207
- if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
3208
- mapData.MapClause [i])) {
3209
- capture = mapOp.getMapCaptureType ().value_or (
3210
- mlir::omp::VariableCaptureKind::ByRef);
3211
- }
3199
+ auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause [i]);
3200
+ capture =
3201
+ mapOp.getMapCaptureType ().value_or (omp::VariableCaptureKind::ByRef);
3212
3202
3213
3203
break ;
3214
3204
}
@@ -3229,18 +3219,18 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3229
3219
builder.restoreIP (codeGenIP);
3230
3220
3231
3221
switch (capture) {
3232
- case mlir:: omp::VariableCaptureKind::ByCopy: {
3222
+ case omp::VariableCaptureKind::ByCopy: {
3233
3223
retVal = v;
3234
3224
break ;
3235
3225
}
3236
- case mlir:: omp::VariableCaptureKind::ByRef: {
3226
+ case omp::VariableCaptureKind::ByRef: {
3237
3227
retVal = builder.CreateAlignedLoad (
3238
3228
v->getType (), v,
3239
3229
ompBuilder.M .getDataLayout ().getPrefTypeAlign (v->getType ()));
3240
3230
break ;
3241
3231
}
3242
- case mlir:: omp::VariableCaptureKind::This:
3243
- case mlir:: omp::VariableCaptureKind::VLAType:
3232
+ case omp::VariableCaptureKind::This:
3233
+ case omp::VariableCaptureKind::VLAType:
3244
3234
assert (false && " Currently unsupported capture kind" );
3245
3235
break ;
3246
3236
}
@@ -3292,8 +3282,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3292
3282
builder.restoreIP (codeGenIP);
3293
3283
unsigned argIndex = 0 ;
3294
3284
for (auto &mapOp : mapVars) {
3295
- auto mapInfoOp =
3296
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp ());
3285
+ auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp ());
3297
3286
llvm::Value *mapOpValue =
3298
3287
moduleTranslation.lookupValue (mapInfoOp.getVarPtr ());
3299
3288
const auto &arg = targetRegion.front ().getArgument (argIndex);
0 commit comments