
Commit 2b6b9c1

[flang][OpenMP] Translate OpenMP scopes when compiling for target device
If a `target` directive is nested in a host OpenMP directive (e.g. `parallel`, `task`, or a worksharing loop), flang currently crashes if the target directive-related MLIR ops (e.g. `omp.map.bounds` and `omp.map.info`) depend on SSA values defined inside the parent host OpenMP directives/ops. This PR solves the problem by treating these parent OpenMP ops as "SSA scopes": whenever we are translating for the device, instead of completely translating the host ops, we translate just their nested MLIR ops as pure SSA values.
1 parent 270790b commit 2b6b9c1
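
For illustration, here is a minimal sketch of the previously-crashing pattern, condensed from the new test added below (value names such as %bounds and %mapv are illustrative; %arg0 stands for the enclosing function's pointer argument):

omp.parallel {
  %c1 = llvm.mlir.constant(1 : index) : i64
  %c4 = llvm.mlir.constant(4 : index) : i64
  // %bounds and %mapv are defined inside the host omp.parallel region...
  %bounds = omp.map.bounds lower_bound(%c1 : i64) upper_bound(%c4 : i64) stride(%c1 : i64) start_idx(%c1 : i64)
  %mapv = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%bounds) -> !llvm.ptr {name = ""}
  // ...and consumed by the nested target op: the combination that used to
  // crash flang when compiling for the target device.
  omp.target map_entries(%mapv -> %map_arg : !llvm.ptr) {
    omp.terminator
  }
  omp.terminator
}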

File tree

2 files changed: +186, -9 lines changed


mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 50 additions & 9 deletions
@@ -537,6 +537,19 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
   llvm_unreachable("Unknown ClauseProcBindKind kind");
 }
 
+/// Maps elements of \p blockArgs (which are MLIR values) to the corresponding
+/// LLVM values of \p operands' elements. This is useful when an OpenMP region
+/// with entry block arguments is converted to LLVM. In this case \p blockArgs
+/// are (part of) the OpenMP region's entry arguments and \p operands are
+/// (part of) the operands to the OpenMP op containing the region.
+static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
+                        omp::BlockArgOpenMPOpInterface blockArgIface) {
+  llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+  blockArgIface.getBlockArgsPairs(blockArgsPairs);
+  for (auto [var, arg] : blockArgsPairs)
+    moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
+}
+
 /// Helper function to map block arguments defined by ignored loop wrappers to
 /// LLVM values and prevent any uses of those from triggering null pointer
 /// dereferences.
@@ -549,17 +562,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
   // Map block arguments directly to the LLVM value associated to the
   // corresponding operand. This is semantically equivalent to this wrapper not
   // being present.
-  auto forwardArgs =
-      [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
-        llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
-        blockArgIface.getBlockArgsPairs(blockArgsPairs);
-        for (auto [var, arg] : blockArgsPairs)
-          moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
-      };
-
   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
       .Case([&](omp::SimdOp op) {
-        forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
+        forwardArgs(moduleTranslation,
+                    cast<omp::BlockArgOpenMPOpInterface>(*op));
         op.emitWarning() << "simd information on composite construct discarded";
         return success();
       })
@@ -5294,6 +5300,7 @@ convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
   return convertHostOrTargetOperation(op, builder, moduleTranslation);
 }
 
+
 static LogicalResult
 convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
@@ -5313,6 +5320,40 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
             return WalkResult::interrupt();
           return WalkResult::skip();
         }
+
+        // Non-target ops might nest target-related ops; therefore, we
+        // translate them as non-OpenMP scopes. Translating them is needed by
+        // nested target-related ops since they might use LLVM values defined
+        // in their parent non-target ops.
+        if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
+            oper->getParentOfType<LLVM::LLVMFuncOp>() &&
+            !oper->getRegions().empty()) {
+          if (auto blockArgsIface =
+                  dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
+            forwardArgs(moduleTranslation, blockArgsIface);
+
+          if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
+            for (auto iv : loopNest.getIVs()) {
+              // Create fake allocas just to maintain IR validity.
+              moduleTranslation.mapValue(
+                  iv, builder.CreateAlloca(
+                          moduleTranslation.convertType(iv.getType())));
+            }
+          }
+
+          for (Region &region : oper->getRegions()) {
+            auto result = convertOmpOpRegions(
+                region, oper->getName().getStringRef().str() + ".fake.region",
+                builder, moduleTranslation);
+            if (failed(handleError(result, *oper)))
+              return WalkResult::interrupt();
+
+            builder.SetInsertPoint(result.get(), result.get()->end());
+          }
+
+          return WalkResult::skip();
+        }
+
         return WalkResult::advance();
       }).wasInterrupted();
   return failure(interrupted);
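
To illustrate the effect of this new code path: when compiling for the device, the body of a host op such as omp.parallel is no longer lowered as a real parallel region; it is emitted as a plain fall-through ".fake.region" block that exists only to keep the SSA values it defines available to nested target ops. Per the CHECK lines of the first test below, the host side of @test_nested_target_in_parallel reduces to roughly this LLVM IR shape:

define void @test_nested_target_in_parallel(ptr %0) {
  br label %omp.parallel.fake.region

omp.parallel.fake.region:        ; host parallel body, kept only as an SSA scope
  br label %omp.region.cont

omp.region.cont:
  ret void
}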
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+
+  omp.private {type = private} @i32_privatizer : i32
+
+  llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
+    omp.parallel {
+      %0 = llvm.mlir.constant(4 : index) : i64
+      %1 = llvm.mlir.constant(1 : index) : i64
+      %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
+      %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
+      omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
+  // CHECK-NEXT: br label %omp.parallel.fake.region
+  // CHECK: omp.parallel.fake.region:
+  // CHECK-NEXT: br label %omp.region.cont
+  // CHECK: omp.region.cont:
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+
+  llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
+    %8 = llvm.mlir.constant(1 : i64) : i64
+    %9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+    %16 = llvm.mlir.constant(10 : i32) : i32
+    %17 = llvm.mlir.constant(1 : i32) : i32
+    omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
+      omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
+        llvm.store %arg1, %loop_arg : i32, !llvm.ptr
+        %0 = llvm.mlir.constant(4 : index) : i64
+        %1 = llvm.mlir.constant(1 : index) : i64
+        %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
+        %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
+        omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
+          omp.terminator
+        }
+        omp.yield
+      }
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
+  // CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
+  // CHECK-NEXT: br label %omp.wsloop.fake.region
+  // CHECK: omp.wsloop.fake.region:
+  // CHECK-NEXT: %{{.*}} = alloca i32, align 4
+  // CHECK-NEXT: br label %omp.loop_nest.fake.region
+  // CHECK: omp.loop_nest.fake.region:
+  // CHECK-NEXT: store ptr %3, ptr %2, align 8
+  // CHECK-NEXT: br label %omp.region.cont1
+  // CHECK: omp.region.cont1:
+  // CHECK-NEXT: br label %omp.region.cont
+  // CHECK: omp.region.cont:
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+
+  llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
+    %8 = llvm.mlir.constant(1 : i64) : i64
+    %9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+    omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
+      %1 = llvm.mlir.constant(1 : index) : i64
+      // Use the private clause from omp.parallel to make sure block arguments
+      // are handled.
+      %i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
+      %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
+      %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
+      omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
+    %8 = llvm.mlir.constant(1 : i64) : i64
+    %9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+    omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
+      %1 = llvm.mlir.constant(1 : index) : i64
+      // Use the private clause from omp.task to make sure block arguments
+      // are handled.
+      %i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
+      %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
+      %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
+      omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
+  // CHECK: br label %omp.parallel.fake.region
+  // CHECK: omp.parallel.fake.region:
+  // CHECK: br label %omp.region.cont
+  // CHECK: omp.region.cont:
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+
+  // CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
+  // CHECK: call i32 @__kmpc_target_init
+  // CHECK: user_code.entry:
+  // CHECK: call void @__kmpc_target_deinit()
+  // CHECK: ret void
+  // CHECK: }
+
+  // CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
+  // CHECK: call i32 @__kmpc_target_init
+  // CHECK: user_code.entry:
+  // CHECK: call void @__kmpc_target_deinit()
+  // CHECK: ret void
+  // CHECK: }
+
+  // CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
+  // CHECK: call i32 @__kmpc_target_init
+  // CHECK: user_code.entry:
+  // CHECK: call void @__kmpc_target_deinit()
+  // CHECK: ret void
+  // CHECK: }
+
+  // CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
+  // CHECK: call i32 @__kmpc_target_init
+  // CHECK: user_code.entry:
+  // CHECK: call void @__kmpc_target_deinit()
+  // CHECK: ret void
+  // CHECK: }
+}
