Skip to content

Commit 0816681

Browse files
committed
[flang][OpenMP] Translate OpenMP scopes when compiling for target device
If a `target` directive is nested in a host OpenMP directive (e.g. parallel, task, or a worksharing loop), flang currently crashes if the target directive-related MLIR ops (e.g. `omp.map.bounds` and `omp.map.info`) depend on SSA values defined inside the parent host OpenMP directives/ops. This PR tries to solve this problem by treating these parent OpenMP ops as "SSA scopes". Whenever we are translating for the device, instead of completely translating host ops, we just translate their MLIR ops as pure SSA values.
1 parent 72bb0a9 commit 0816681

File tree

2 files changed

+218
-9
lines changed

2 files changed

+218
-9
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,18 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
537537
llvm_unreachable("Unknown ClauseProcBindKind kind");
538538
}
539539

540+
/// Maps elements of \p blockArgs (which are MLIR values) to the corresponding
541+
/// LLVM values of \p operands' elements. This is useful when an OpenMP region
542+
/// with entry block arguments is converted to LLVM. In this case \p blockArgs
543+
/// are (part of) the OpenMP region's entry arguments and \p operands are
544+
/// (part of) the operands to the OpenMP op containing the region.
545+
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
546+
llvm::ArrayRef<BlockArgument> blockArgs,
547+
OperandRange operands) {
548+
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
549+
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
550+
}
551+
540552
/// Helper function to map block arguments defined by ignored loop wrappers to
541553
/// LLVM values and prevent any uses of those from triggering null pointer
542554
/// dereferences.
@@ -549,18 +561,12 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
549561
// Map block arguments directly to the LLVM value associated to the
550562
// corresponding operand. This is semantically equivalent to this wrapper not
551563
// being present.
552-
auto forwardArgs =
553-
[&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
554-
OperandRange operands) {
555-
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
556-
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
557-
};
558-
559564
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
560565
.Case([&](omp::SimdOp op) {
561566
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
562-
forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
563-
forwardArgs(blockArgIface.getReductionBlockArgs(),
567+
forwardArgs(moduleTranslation, blockArgIface.getPrivateBlockArgs(),
568+
op.getPrivateVars());
569+
forwardArgs(moduleTranslation, blockArgIface.getReductionBlockArgs(),
564570
op.getReductionVars());
565571
op.emitWarning() << "simd information on composite construct discarded";
566572
return success();
@@ -5296,6 +5302,28 @@ convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
52965302
return convertHostOrTargetOperation(op, builder, moduleTranslation);
52975303
}
52985304

5305+
/// Forwards private entry block arguments, \see forwardArgs for more details.
5306+
template <typename OMPOp>
5307+
static void forwardPrivateArgs(OMPOp ompOp,
5308+
LLVM::ModuleTranslation &moduleTranslation) {
5309+
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5310+
if (blockArgIface) {
5311+
forwardArgs(moduleTranslation, blockArgIface.getPrivateBlockArgs(),
5312+
ompOp.getPrivateVars());
5313+
}
5314+
}
5315+
5316+
/// Forwards reduction entry block arguments, \see forwardArgs for more details.
5317+
template <typename OMPOp>
5318+
static void forwardReductionArgs(OMPOp ompOp,
5319+
LLVM::ModuleTranslation &moduleTranslation) {
5320+
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5321+
if (blockArgIface) {
5322+
forwardArgs(moduleTranslation, blockArgIface.getReductionBlockArgs(),
5323+
ompOp.getReductionVars());
5324+
}
5325+
}
5326+
52995327
static LogicalResult
53005328
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
53015329
LLVM::ModuleTranslation &moduleTranslation) {
@@ -5315,6 +5343,51 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
53155343
return WalkResult::interrupt();
53165344
return WalkResult::skip();
53175345
}
5346+
5347+
// Non-target ops might nest target-related ops, therefore, we
5348+
// translate them as non-OpenMP scopes. Translating them is needed by
5349+
// nested target-related ops since they might use LLVM values defined in
5350+
// their parent non-target ops.
5351+
if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
5352+
oper->getParentOfType<LLVM::LLVMFuncOp>() &&
5353+
!oper->getRegions().empty()) {
5354+
5355+
// TODO Handle other ops with entry block args.
5356+
llvm::TypeSwitch<Operation &>(*oper)
5357+
.Case([&](omp::WsloopOp wsloopOp) {
5358+
forwardPrivateArgs(wsloopOp, moduleTranslation);
5359+
forwardReductionArgs(wsloopOp, moduleTranslation);
5360+
})
5361+
.Case([&](omp::ParallelOp parallelOp) {
5362+
forwardPrivateArgs(parallelOp, moduleTranslation);
5363+
forwardReductionArgs(parallelOp, moduleTranslation);
5364+
})
5365+
.Case([&](omp::TaskOp taskOp) {
5366+
forwardPrivateArgs(taskOp, moduleTranslation);
5367+
});
5368+
5369+
if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5370+
for (auto iv : loopNest.getIVs()) {
5371+
// Create fake allocas just to maintain IR validity.
5372+
moduleTranslation.mapValue(
5373+
iv, builder.CreateAlloca(
5374+
moduleTranslation.convertType(iv.getType())));
5375+
}
5376+
}
5377+
5378+
for (Region &region : oper->getRegions()) {
5379+
auto result = convertOmpOpRegions(
5380+
region, oper->getName().getStringRef().str() + ".fake.region",
5381+
builder, moduleTranslation);
5382+
if (failed(handleError(result, *oper)))
5383+
return WalkResult::interrupt();
5384+
5385+
builder.SetInsertPoint(result.get(), result.get()->end());
5386+
}
5387+
5388+
return WalkResult::skip();
5389+
}
5390+
53185391
return WalkResult::advance();
53195392
}).wasInterrupted();
53205393
return failure(interrupted);
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
4+
5+
omp.private {type = private} @i32_privatizer : i32
6+
7+
llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
8+
omp.parallel {
9+
%0 = llvm.mlir.constant(4 : index) : i64
10+
%1 = llvm.mlir.constant(1 : index) : i64
11+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
12+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
13+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
14+
omp.terminator
15+
}
16+
omp.terminator
17+
}
18+
llvm.return
19+
}
20+
21+
// CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
22+
// CHECK-NEXT: br label %omp.parallel.fake.region
23+
// CHECK: omp.parallel.fake.region:
24+
// CHECK-NEXT: br label %omp.region.cont
25+
// CHECK: omp.region.cont:
26+
// CHECK-NEXT: ret void
27+
// CHECK-NEXT: }
28+
29+
llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
30+
%8 = llvm.mlir.constant(1 : i64) : i64
31+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
32+
%16 = llvm.mlir.constant(10 : i32) : i32
33+
%17 = llvm.mlir.constant(1 : i32) : i32
34+
omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
35+
omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
36+
llvm.store %arg1, %loop_arg : i32, !llvm.ptr
37+
%0 = llvm.mlir.constant(4 : index) : i64
38+
%1 = llvm.mlir.constant(1 : index) : i64
39+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
40+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
41+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
42+
omp.terminator
43+
}
44+
omp.yield
45+
}
46+
}
47+
llvm.return
48+
}
49+
50+
// CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
51+
// CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
52+
// CHECK-NEXT: br label %omp.wsloop.fake.region
53+
// CHECK: omp.wsloop.fake.region:
54+
// CHECK-NEXT: %{{.*}} = alloca i32, align 4
55+
// CHECK-NEXT: br label %omp.loop_nest.fake.region
56+
// CHECK: omp.loop_nest.fake.region:
57+
// CHECK-NEXT: store ptr %3, ptr %2, align 8
58+
// CHECK-NEXT: br label %omp.region.cont1
59+
// CHECK: omp.region.cont1:
60+
// CHECK-NEXT: br label %omp.region.cont
61+
// CHECK: omp.region.cont:
62+
// CHECK-NEXT: ret void
63+
// CHECK-NEXT: }
64+
65+
llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
66+
%8 = llvm.mlir.constant(1 : i64) : i64
67+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
68+
omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
69+
%1 = llvm.mlir.constant(1 : index) : i64
70+
// Use the private clause from omp.parallel to make sure block arguments
71+
// are handled.
72+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
73+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
74+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
75+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
76+
omp.terminator
77+
}
78+
omp.terminator
79+
}
80+
llvm.return
81+
}
82+
83+
llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
84+
%8 = llvm.mlir.constant(1 : i64) : i64
85+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
86+
omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
87+
%1 = llvm.mlir.constant(1 : index) : i64
88+
// Use the private clause from omp.task to make sure block arguments
89+
// are handled.
90+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
91+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
92+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
93+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
94+
omp.terminator
95+
}
96+
omp.terminator
97+
}
98+
llvm.return
99+
}
100+
101+
// CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
102+
// CHECK: br label %omp.parallel.fake.region
103+
// CHECK: omp.parallel.fake.region:
104+
// CHECK: br label %omp.region.cont
105+
// CHECK: omp.region.cont:
106+
// CHECK-NEXT: ret void
107+
// CHECK-NEXT: }
108+
109+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
110+
// CHECK: call i32 @__kmpc_target_init
111+
// CHECK: user_code.entry:
112+
// CHECK: call void @__kmpc_target_deinit()
113+
// CHECK: ret void
114+
// CHECK: }
115+
116+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
117+
// CHECK: call i32 @__kmpc_target_init
118+
// CHECK: user_code.entry:
119+
// CHECK: call void @__kmpc_target_deinit()
120+
// CHECK: ret void
121+
// CHECK: }
122+
123+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
124+
// CHECK: call i32 @__kmpc_target_init
125+
// CHECK: user_code.entry:
126+
// CHECK: call void @__kmpc_target_deinit()
127+
// CHECK: ret void
128+
// CHECK: }
129+
130+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
131+
// CHECK: call i32 @__kmpc_target_init
132+
// CHECK: user_code.entry:
133+
// CHECK: call void @__kmpc_target_deinit()
134+
// CHECK: ret void
135+
// CHECK: }
136+
}

0 commit comments

Comments
 (0)