Skip to content

Commit e737b84

Browse files
authored
[flang][OpenMP] Translate OpenMP scopes when compiling for target device (#130078)
If a `target` directive is nested in a host OpenMP directive (e.g. parallel, task, or a worksharing loop), flang currently crashes if the target directive-related MLIR ops (e.g. `omp.map.bounds` and `omp.map.info`) depend on SSA values defined inside the parent host OpenMP directives/ops. This PR tries to solve this problem by treating these parent OpenMP ops as "SSA scopes". Whenever we are translating for the device, instead of completely translating host ops, we just translate their MLIR ops as pure SSA values.
1 parent e295f5d commit e737b84

File tree

2 files changed

+192
-9
lines changed

2 files changed

+192
-9
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/ADT/TypeSwitch.h"
3030
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3131
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32+
#include "llvm/IR/Constants.h"
3233
#include "llvm/IR/DebugInfoMetadata.h"
3334
#include "llvm/IR/DerivedTypes.h"
3435
#include "llvm/IR/IRBuilder.h"
@@ -536,6 +537,20 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
536537
llvm_unreachable("Unknown ClauseProcBindKind kind");
537538
}
538539

540+
/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
541+
/// corresponding LLVM values of the interface's operands. This is useful
542+
/// when an OpenMP region with entry block arguments is converted to LLVM. In
543+
/// this case the block arguments are (part of) the OpenMP region's entry
544+
/// arguments and the operands are (part of) the operands to the OpenMP op
545+
/// containing the region.
546+
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
547+
omp::BlockArgOpenMPOpInterface blockArgIface) {
548+
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
549+
blockArgIface.getBlockArgsPairs(blockArgsPairs);
550+
for (auto [var, arg] : blockArgsPairs)
551+
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
552+
}
553+
539554
/// Helper function to map block arguments defined by ignored loop wrappers to
540555
/// LLVM values and prevent any uses of those from triggering null pointer
541556
/// dereferences.
@@ -548,17 +563,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
548563
// Map block arguments directly to the LLVM value associated to the
549564
// corresponding operand. This is semantically equivalent to this wrapper not
550565
// being present.
551-
auto forwardArgs =
552-
[&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
553-
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
554-
blockArgIface.getBlockArgsPairs(blockArgsPairs);
555-
for (auto [var, arg] : blockArgsPairs)
556-
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
557-
};
558-
559566
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
560567
.Case([&](omp::SimdOp op) {
561-
forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
568+
forwardArgs(moduleTranslation,
569+
cast<omp::BlockArgOpenMPOpInterface>(*op));
562570
op.emitWarning() << "simd information on composite construct discarded";
563571
return success();
564572
})
@@ -5351,6 +5359,46 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
53515359
return WalkResult::interrupt();
53525360
return WalkResult::skip();
53535361
}
5362+
5363+
// Non-target ops might nest target-related ops, therefore, we
5364+
// translate them as non-OpenMP scopes. Translating them is needed by
5365+
// nested target-related ops since they might need LLVM values defined
5366+
// in their parent non-target ops.
5367+
if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
5368+
oper->getParentOfType<LLVM::LLVMFuncOp>() &&
5369+
!oper->getRegions().empty()) {
5370+
if (auto blockArgsIface =
5371+
dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
5372+
forwardArgs(moduleTranslation, blockArgsIface);
5373+
5374+
if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5375+
assert(builder.GetInsertBlock() &&
5376+
"No insert block is set for the builder");
5377+
for (auto iv : loopNest.getIVs()) {
5378+
// Map iv to an undefined value just to keep the IR validity.
5379+
moduleTranslation.mapValue(
5380+
iv, llvm::PoisonValue::get(
5381+
moduleTranslation.convertType(iv.getType())));
5382+
}
5383+
}
5384+
5385+
for (Region &region : oper->getRegions()) {
5386+
// Regions are fake in the sense that they are not a truthful
5387+
// translation of the OpenMP construct being converted (e.g. no
5388+
// OpenMP runtime calls will be generated). We just need this to
5389+
// prepare the kernel invocation args.
5390+
auto result = convertOmpOpRegions(
5391+
region, oper->getName().getStringRef().str() + ".fake.region",
5392+
builder, moduleTranslation);
5393+
if (failed(handleError(result, *oper)))
5394+
return WalkResult::interrupt();
5395+
5396+
builder.SetInsertPoint(result.get(), result.get()->end());
5397+
}
5398+
5399+
return WalkResult::skip();
5400+
}
5401+
53545402
return WalkResult::advance();
53555403
}).wasInterrupted();
53565404
return failure(interrupted);
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
4+
5+
omp.private {type = private} @i32_privatizer : i32
6+
7+
llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
8+
omp.parallel {
9+
%0 = llvm.mlir.constant(4 : index) : i64
10+
%1 = llvm.mlir.constant(1 : index) : i64
11+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
12+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
13+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
14+
omp.terminator
15+
}
16+
omp.terminator
17+
}
18+
llvm.return
19+
}
20+
21+
// CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
22+
// CHECK-NEXT: br label %omp.parallel.fake.region
23+
// CHECK: omp.parallel.fake.region:
24+
// CHECK-NEXT: br label %omp.region.cont
25+
// CHECK: omp.region.cont:
26+
// CHECK-NEXT: ret void
27+
// CHECK-NEXT: }
28+
29+
llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
30+
%8 = llvm.mlir.constant(1 : i64) : i64
31+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
32+
%16 = llvm.mlir.constant(10 : i32) : i32
33+
%17 = llvm.mlir.constant(1 : i32) : i32
34+
omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
35+
omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
36+
llvm.store %arg1, %loop_arg : i32, !llvm.ptr
37+
%0 = llvm.mlir.constant(4 : index) : i64
38+
%1 = llvm.mlir.constant(1 : index) : i64
39+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
40+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
41+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
42+
omp.terminator
43+
}
44+
omp.yield
45+
}
46+
}
47+
llvm.return
48+
}
49+
50+
// CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
51+
// CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
52+
// CHECK-NEXT: br label %omp.wsloop.fake.region
53+
// CHECK: omp.wsloop.fake.region:
54+
// CHECK-NEXT: br label %omp.loop_nest.fake.region
55+
// CHECK: omp.loop_nest.fake.region:
56+
// CHECK-NEXT: store i32 poison, ptr %{{.*}}
57+
// CHECK-NEXT: br label %omp.region.cont1
58+
// CHECK: omp.region.cont1:
59+
// CHECK-NEXT: br label %omp.region.cont
60+
// CHECK: omp.region.cont:
61+
// CHECK-NEXT: ret void
62+
// CHECK-NEXT: }
63+
64+
llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
65+
%8 = llvm.mlir.constant(1 : i64) : i64
66+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
67+
omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
68+
%1 = llvm.mlir.constant(1 : index) : i64
69+
// Use the private clause from omp.parallel to make sure block arguments
70+
// are handled.
71+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
72+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
73+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
74+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
75+
omp.terminator
76+
}
77+
omp.terminator
78+
}
79+
llvm.return
80+
}
81+
82+
llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
83+
%8 = llvm.mlir.constant(1 : i64) : i64
84+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
85+
omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
86+
%1 = llvm.mlir.constant(1 : index) : i64
87+
// Use the private clause from omp.task to make sure block arguments
88+
// are handled.
89+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
90+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
91+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
92+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
93+
omp.terminator
94+
}
95+
omp.terminator
96+
}
97+
llvm.return
98+
}
99+
100+
// CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
101+
// CHECK: br label %omp.parallel.fake.region
102+
// CHECK: omp.parallel.fake.region:
103+
// CHECK: br label %omp.region.cont
104+
// CHECK: omp.region.cont:
105+
// CHECK-NEXT: ret void
106+
// CHECK-NEXT: }
107+
108+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
109+
// CHECK: call i32 @__kmpc_target_init
110+
// CHECK: user_code.entry:
111+
// CHECK: call void @__kmpc_target_deinit()
112+
// CHECK: ret void
113+
// CHECK: }
114+
115+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
116+
// CHECK: call i32 @__kmpc_target_init
117+
// CHECK: user_code.entry:
118+
// CHECK: call void @__kmpc_target_deinit()
119+
// CHECK: ret void
120+
// CHECK: }
121+
122+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
123+
// CHECK: call i32 @__kmpc_target_init
124+
// CHECK: user_code.entry:
125+
// CHECK: call void @__kmpc_target_deinit()
126+
// CHECK: ret void
127+
// CHECK: }
128+
129+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
130+
// CHECK: call i32 @__kmpc_target_init
131+
// CHECK: user_code.entry:
132+
// CHECK: call void @__kmpc_target_deinit()
133+
// CHECK: ret void
134+
// CHECK: }
135+
}

0 commit comments

Comments
 (0)