Skip to content

Commit 9cb5045

Browse files
committed
Addressed reviewer comments.
1 parent 9f34645 commit 9cb5045

File tree

3 files changed

+47
-46
lines changed

3 files changed

+47
-46
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6354,14 +6354,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
63546354
if (!updateToLocation(Loc))
63556355
return InsertPointTy();
63566356

6357+
Builder.restoreIP(CodeGenIP);
63576358
// Disable TargetData CodeGen on Device pass.
63586359
if (Config.IsTargetDevice.value_or(false)) {
63596360
if (BodyGenCB)
6360-
Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
6361+
Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
63616362
return Builder.saveIP();
63626363
}
63636364

6364-
Builder.restoreIP(CodeGenIP);
63656365
bool IsStandAlone = !BodyGenCB;
63666366
MapInfosTy *MapInfo;
63676367
// Generate the code for the opening of the data environment. Capture all the

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2242,14 +2242,10 @@ static void collectMapDataFromMapOperands(
22422242
// types, i.e. member maps that can have member maps.
22432243
mapData.IsAMember.push_back(false);
22442244
for (Value mapValue : mapVars) {
2245-
if (auto map =
2246-
dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp())) {
2247-
for (auto member : map.getMembers()) {
2248-
if (member == mapOp) {
2249-
mapData.IsAMember.back() = true;
2250-
}
2251-
}
2252-
}
2245+
auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2246+
for (auto member : map.getMembers())
2247+
if (member == mapOp)
2248+
mapData.IsAMember.back() = true;
22532249
}
22542250
}
22552251

@@ -2303,12 +2299,12 @@ static void collectMapDataFromMapOperands(
23032299
// TODO: May require some further additions to support nested record
23042300
// types, i.e. member maps that can have member maps.
23052301
mapData.IsAMember.push_back(false);
2306-
for (Value mapValue : useDevOperands)
2307-
if (auto map =
2308-
dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp()))
2309-
for (auto member : map.getMembers())
2310-
if (member == mapOp)
2311-
mapData.IsAMember.back() = true;
2302+
for (Value mapValue : useDevOperands) {
2303+
auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2304+
for (auto member : map.getMembers())
2305+
if (member == mapOp)
2306+
mapData.IsAMember.back() = true;
2307+
}
23122308
}
23132309
}
23142310
};
@@ -2713,7 +2709,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
27132709
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
27142710
// if it's declare target, skip it, it's handled separately.
27152711
if (!mapData.IsDeclareTarget[i]) {
2716-
auto mapOp = dyn_cast_if_present<omp::MapInfoOp>(mapData.MapClause[i]);
2712+
auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
27172713
omp::VariableCaptureKind captureKind =
27182714
mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
27192715
bool isPtrTy = checkIfPointerMap(mapOp);
@@ -2935,20 +2931,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
29352931
if (!info.DevicePtrInfoMap.empty()) {
29362932
builder.restoreIP(codeGenIP);
29372933
unsigned argIndex = 0;
2938-
for (size_t i = 0; i < combinedInfo.BasePointers.size(); ++i) {
2939-
if (combinedInfo.DevicePointers[i] ==
2940-
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
2934+
for (auto [basePointer, devicePointer] : llvm::zip_equal(
2935+
combinedInfo.BasePointers, combinedInfo.DevicePointers)) {
2936+
if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
29412937
const auto &arg = region.front().getArgument(argIndex);
29422938
moduleTranslation.mapValue(
2943-
arg,
2944-
info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
2939+
arg, info.DevicePtrInfoMap[basePointer].second);
29452940
argIndex++;
2946-
} else if (combinedInfo.DevicePointers[i] ==
2941+
} else if (devicePointer ==
29472942
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
29482943
const auto &arg = region.front().getArgument(argIndex);
29492944
auto *loadInst = builder.CreateLoad(
2950-
builder.getPtrTy(),
2951-
info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
2945+
builder.getPtrTy(), info.DevicePtrInfoMap[basePointer].second);
29522946
moduleTranslation.mapValue(arg, loadInst);
29532947
argIndex++;
29542948
}
@@ -2968,13 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
29682962
// we need to link them here before codegen.
29692963
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
29702964
unsigned argIndex = 0;
2971-
for (size_t i = 0; i < mapData.BasePointers.size(); ++i) {
2972-
if (mapData.DevicePointers[i] ==
2973-
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2974-
mapData.DevicePointers[i] ==
2975-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
2965+
for (auto [basePointer, devicePointer] :
2966+
llvm::zip_equal(mapData.BasePointers, mapData.DevicePointers)) {
2967+
if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
2968+
devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
29762969
const auto &arg = region.front().getArgument(argIndex);
2977-
moduleTranslation.mapValue(arg, mapData.BasePointers[i]);
2970+
moduleTranslation.mapValue(arg, basePointer);
29782971
argIndex++;
29792972
}
29802973
}
@@ -3198,17 +3191,14 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
31983191
llvm::IRBuilderBase::InsertPoint codeGenIP) {
31993192
builder.restoreIP(allocaIP);
32003193

3201-
mlir::omp::VariableCaptureKind capture =
3202-
mlir::omp::VariableCaptureKind::ByRef;
3194+
omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
32033195

32043196
// Find the associated MapInfoData entry for the current input
32053197
for (size_t i = 0; i < mapData.MapClause.size(); ++i)
32063198
if (mapData.OriginalValue[i] == input) {
3207-
if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
3208-
mapData.MapClause[i])) {
3209-
capture = mapOp.getMapCaptureType().value_or(
3210-
mlir::omp::VariableCaptureKind::ByRef);
3211-
}
3199+
auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3200+
capture =
3201+
mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
32123202

32133203
break;
32143204
}
@@ -3229,18 +3219,18 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
32293219
builder.restoreIP(codeGenIP);
32303220

32313221
switch (capture) {
3232-
case mlir::omp::VariableCaptureKind::ByCopy: {
3222+
case omp::VariableCaptureKind::ByCopy: {
32333223
retVal = v;
32343224
break;
32353225
}
3236-
case mlir::omp::VariableCaptureKind::ByRef: {
3226+
case omp::VariableCaptureKind::ByRef: {
32373227
retVal = builder.CreateAlignedLoad(
32383228
v->getType(), v,
32393229
ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
32403230
break;
32413231
}
3242-
case mlir::omp::VariableCaptureKind::This:
3243-
case mlir::omp::VariableCaptureKind::VLAType:
3232+
case omp::VariableCaptureKind::This:
3233+
case omp::VariableCaptureKind::VLAType:
32443234
assert(false && "Currently unsupported capture kind");
32453235
break;
32463236
}
@@ -3292,8 +3282,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
32923282
builder.restoreIP(codeGenIP);
32933283
unsigned argIndex = 0;
32943284
for (auto &mapOp : mapVars) {
3295-
auto mapInfoOp =
3296-
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
3285+
auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
32973286
llvm::Value *mapOpValue =
32983287
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
32993288
const auto &arg = targetRegion.front().getArgument(argIndex);

mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,20 @@
33
// This tests check that target code nested inside a target data region which
44
// has only use_device_ptr mapping corectly generates code on the device pass.
55

6-
// CHECK-NOT: call void @__tgt_target_data_begin_mapper
7-
// CHECK: store i32 999, ptr {{.*}}
6+
// CHECK: define weak_odr protected void @__omp_offloading{{.*}}main_
7+
// CHECK-NEXT: entry:
8+
// CHECK-NEXT: %[[VAL_3:.*]] = alloca ptr, align 8
9+
// CHECK-NEXT: store ptr %[[VAL_4:.*]], ptr %[[VAL_3]], align 8
10+
// CHECK-NEXT: %[[VAL_5:.*]] = call i32 @__kmpc_target_init(ptr @__omp_offloading_{{.*}}_kernel_environment, ptr %[[VAL_6:.*]])
11+
// CHECK-NEXT: %[[VAL_7:.*]] = icmp eq i32 %[[VAL_5]], -1
12+
// CHECK-NEXT: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]]
13+
// CHECK: user_code.entry: ; preds = %[[VAL_10:.*]]
14+
// CHECK-NEXT: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_3]], align 8
15+
// CHECK-NEXT: br label %[[VAL_12:.*]]
16+
// CHECK: omp.target: ; preds = %[[VAL_8]]
17+
// CHECK-NEXT: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
18+
// CHECK-NEXT: store i32 999, ptr %[[VAL_13]], align 4
19+
// CHECK-NEXT: br label %[[VAL_14:.*]]
820
module attributes {omp.is_target_device = true } {
921
llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
1022
%0 = llvm.mlir.constant(1 : i64) : i64

0 commit comments

Comments
 (0)