Skip to content

Commit bffc2aa

Browse files
committed
[flang][OpenMP] Translate OpenMP scopes when compiling for target device
If a `target` directive is nested in a host OpenMP directive (e.g. parallel, task, or a worksharing loop), flang currently crashes if the target directive-related MLIR ops (e.g. `omp.map.bounds` and `omp.map.info`) depend on SSA values defined inside the parent host OpenMP directives/ops. This PR tries to solve this problem by treating these parent OpenMP ops as "SSA scopes". Whenever we are translating for the device, instead of completely translating host ops, we just translate their MLIR ops as pure SSA values.
1 parent e15545c commit bffc2aa

File tree

2 files changed

+218
-9
lines changed

2 files changed

+218
-9
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,18 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
542542
llvm_unreachable("Unknown ClauseProcBindKind kind");
543543
}
544544

545+
/// Maps elements of \p blockArgs (which are MLIR values) to the corresponding
546+
/// LLVM values of \p operands' elements. This is useful when an OpenMP region
547+
/// with entry block arguments is converted to LLVM. In this case \p blockArgs
548+
/// are (part of) the OpenMP region's entry arguments and \p operands are
549+
/// (part of) the operands to the OpenMP op containing the region.
550+
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
551+
llvm::ArrayRef<BlockArgument> blockArgs,
552+
OperandRange operands) {
553+
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
554+
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
555+
}
556+
545557
/// Helper function to map block arguments defined by ignored loop wrappers to
546558
/// LLVM values and prevent any uses of those from triggering null pointer
547559
/// dereferences.
@@ -554,18 +566,12 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
554566
// Map block arguments directly to the LLVM value associated to the
555567
// corresponding operand. This is semantically equivalent to this wrapper not
556568
// being present.
557-
auto forwardArgs =
558-
[&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
559-
OperandRange operands) {
560-
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
561-
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
562-
};
563-
564569
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
565570
.Case([&](omp::SimdOp op) {
566571
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
567-
forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
568-
forwardArgs(blockArgIface.getReductionBlockArgs(),
572+
forwardArgs(moduleTranslation, blockArgIface.getPrivateBlockArgs(),
573+
op.getPrivateVars());
574+
forwardArgs(moduleTranslation, blockArgIface.getReductionBlockArgs(),
569575
op.getReductionVars());
570576
op.emitWarning() << "simd information on composite construct discarded";
571577
return success();
@@ -5236,6 +5242,28 @@ convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
52365242
return convertHostOrTargetOperation(op, builder, moduleTranslation);
52375243
}
52385244

5245+
/// Forwards private entry block arguments, \see forwardArgs for more details.
5246+
template <typename OMPOp>
5247+
static void forwardPrivateArgs(OMPOp ompOp,
5248+
LLVM::ModuleTranslation &moduleTranslation) {
5249+
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5250+
if (blockArgIface) {
5251+
forwardArgs(moduleTranslation, blockArgIface.getPrivateBlockArgs(),
5252+
ompOp.getPrivateVars());
5253+
}
5254+
}
5255+
5256+
/// Forwards reduction entry block arguments, \see forwardArgs for more details.
5257+
template <typename OMPOp>
5258+
static void forwardReductionArgs(OMPOp ompOp,
5259+
LLVM::ModuleTranslation &moduleTranslation) {
5260+
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5261+
if (blockArgIface) {
5262+
forwardArgs(moduleTranslation, blockArgIface.getReductionBlockArgs(),
5263+
ompOp.getReductionVars());
5264+
}
5265+
}
5266+
52395267
static LogicalResult
52405268
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
52415269
LLVM::ModuleTranslation &moduleTranslation) {
@@ -5255,6 +5283,51 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
52555283
return WalkResult::interrupt();
52565284
return WalkResult::skip();
52575285
}
5286+
5287+
// Non-target ops might nest target-related ops, therefore, we
5288+
// translate them as non-OpenMP scopes. Translating them is needed by
5289+
// nested target-related ops since they might use LLVM values defined in
5290+
// their parent non-target ops.
5291+
if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
5292+
oper->getParentOfType<LLVM::LLVMFuncOp>() &&
5293+
!oper->getRegions().empty()) {
5294+
5295+
// TODO Handle other ops with entry block args.
5296+
llvm::TypeSwitch<Operation &>(*oper)
5297+
.Case([&](omp::WsloopOp wsloopOp) {
5298+
forwardPrivateArgs(wsloopOp, moduleTranslation);
5299+
forwardReductionArgs(wsloopOp, moduleTranslation);
5300+
})
5301+
.Case([&](omp::ParallelOp parallelOp) {
5302+
forwardPrivateArgs(parallelOp, moduleTranslation);
5303+
forwardReductionArgs(parallelOp, moduleTranslation);
5304+
})
5305+
.Case([&](omp::TaskOp taskOp) {
5306+
forwardPrivateArgs(taskOp, moduleTranslation);
5307+
});
5308+
5309+
if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5310+
for (auto iv : loopNest.getIVs()) {
5311+
// Create fake allocas just to maintain IR validity.
5312+
moduleTranslation.mapValue(
5313+
iv, builder.CreateAlloca(
5314+
moduleTranslation.convertType(iv.getType())));
5315+
}
5316+
}
5317+
5318+
for (Region &region : oper->getRegions()) {
5319+
auto result = convertOmpOpRegions(
5320+
region, oper->getName().getStringRef().str() + ".fake.region",
5321+
builder, moduleTranslation);
5322+
if (failed(handleError(result, *oper)))
5323+
return WalkResult::interrupt();
5324+
5325+
builder.SetInsertPoint(result.get(), result.get()->end());
5326+
}
5327+
5328+
return WalkResult::skip();
5329+
}
5330+
52585331
return WalkResult::advance();
52595332
}).wasInterrupted();
52605333
return failure(interrupted);
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
4+
5+
omp.private {type = private} @i32_privatizer : i32
6+
7+
llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
8+
omp.parallel {
9+
%0 = llvm.mlir.constant(4 : index) : i64
10+
%1 = llvm.mlir.constant(1 : index) : i64
11+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
12+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
13+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
14+
omp.terminator
15+
}
16+
omp.terminator
17+
}
18+
llvm.return
19+
}
20+
21+
// CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
22+
// CHECK-NEXT: br label %omp.parallel.fake.region
23+
// CHECK: omp.parallel.fake.region:
24+
// CHECK-NEXT: br label %omp.region.cont
25+
// CHECK: omp.region.cont:
26+
// CHECK-NEXT: ret void
27+
// CHECK-NEXT: }
28+
29+
llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
30+
%8 = llvm.mlir.constant(1 : i64) : i64
31+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
32+
%16 = llvm.mlir.constant(10 : i32) : i32
33+
%17 = llvm.mlir.constant(1 : i32) : i32
34+
omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
35+
omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
36+
llvm.store %arg1, %loop_arg : i32, !llvm.ptr
37+
%0 = llvm.mlir.constant(4 : index) : i64
38+
%1 = llvm.mlir.constant(1 : index) : i64
39+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
40+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
41+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
42+
omp.terminator
43+
}
44+
omp.yield
45+
}
46+
}
47+
llvm.return
48+
}
49+
50+
// CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
51+
// CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
52+
// CHECK-NEXT: br label %omp.wsloop.fake.region
53+
// CHECK: omp.wsloop.fake.region:
54+
// CHECK-NEXT: %{{.*}} = alloca i32, align 4
55+
// CHECK-NEXT: br label %omp.loop_nest.fake.region
56+
// CHECK: omp.loop_nest.fake.region:
57+
// CHECK-NEXT: store ptr %3, ptr %2, align 8
58+
// CHECK-NEXT: br label %omp.region.cont1
59+
// CHECK: omp.region.cont1:
60+
// CHECK-NEXT: br label %omp.region.cont
61+
// CHECK: omp.region.cont:
62+
// CHECK-NEXT: ret void
63+
// CHECK-NEXT: }
64+
65+
llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
66+
%8 = llvm.mlir.constant(1 : i64) : i64
67+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
68+
omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
69+
%1 = llvm.mlir.constant(1 : index) : i64
70+
// Use the private clause from omp.parallel to make sure block arguments
71+
// are handled.
72+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
73+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
74+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
75+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
76+
omp.terminator
77+
}
78+
omp.terminator
79+
}
80+
llvm.return
81+
}
82+
83+
llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
84+
%8 = llvm.mlir.constant(1 : i64) : i64
85+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
86+
omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
87+
%1 = llvm.mlir.constant(1 : index) : i64
88+
// Use the private clause from omp.task to make sure block arguments
89+
// are handled.
90+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
91+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
92+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
93+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
94+
omp.terminator
95+
}
96+
omp.terminator
97+
}
98+
llvm.return
99+
}
100+
101+
// CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
102+
// CHECK: br label %omp.parallel.fake.region
103+
// CHECK: omp.parallel.fake.region:
104+
// CHECK: br label %omp.region.cont
105+
// CHECK: omp.region.cont:
106+
// CHECK-NEXT: ret void
107+
// CHECK-NEXT: }
108+
109+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
110+
// CHECK: call i32 @__kmpc_target_init
111+
// CHECK: user_code.entry:
112+
// CHECK: call void @__kmpc_target_deinit()
113+
// CHECK: ret void
114+
// CHECK: }
115+
116+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
117+
// CHECK: call i32 @__kmpc_target_init
118+
// CHECK: user_code.entry:
119+
// CHECK: call void @__kmpc_target_deinit()
120+
// CHECK: ret void
121+
// CHECK: }
122+
123+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
124+
// CHECK: call i32 @__kmpc_target_init
125+
// CHECK: user_code.entry:
126+
// CHECK: call void @__kmpc_target_deinit()
127+
// CHECK: ret void
128+
// CHECK: }
129+
130+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
131+
// CHECK: call i32 @__kmpc_target_init
132+
// CHECK: user_code.entry:
133+
// CHECK: call void @__kmpc_target_deinit()
134+
// CHECK: ret void
135+
// CHECK: }
136+
}

0 commit comments

Comments
 (0)