Skip to content

Commit a96f44c

Browse files
authored
Amd/dev/rlieberm/restore2flang commits (llvm#1288)
2 parents ab04a8a + 134e3f6 commit a96f44c

File tree

2 files changed

+228
-9
lines changed

2 files changed

+228
-9
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3131
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3232
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
33+
#include "llvm/IR/Constants.h"
3334
#include "llvm/IR/DebugInfoMetadata.h"
3435
#include "llvm/IR/DerivedTypes.h"
3536
#include "llvm/IR/IRBuilder.h"
@@ -542,6 +543,20 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
542543
llvm_unreachable("Unknown ClauseProcBindKind kind");
543544
}
544545

546+
/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
547+
/// corresponding LLVM values of \p the interface's operands. This is useful
548+
/// when an OpenMP region with entry block arguments is converted to LLVM. In
549+
/// this case the block arguments are (part of) of the OpenMP region's entry
550+
/// arguments and the operands are (part of) of the operands to the OpenMP op
551+
/// containing the region.
552+
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
553+
omp::BlockArgOpenMPOpInterface blockArgIface) {
554+
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
555+
blockArgIface.getBlockArgsPairs(blockArgsPairs);
556+
for (auto [var, arg] : blockArgsPairs)
557+
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
558+
}
559+
545560
/// Helper function to map block arguments defined by ignored loop wrappers to
546561
/// LLVM values and prevent any uses of those from triggering null pointer
547562
/// dereferences.
@@ -554,17 +569,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
554569
// Map block arguments directly to the LLVM value associated to the
555570
// corresponding operand. This is semantically equivalent to this wrapper not
556571
// being present.
557-
auto forwardArgs =
558-
[&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
559-
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
560-
blockArgIface.getBlockArgsPairs(blockArgsPairs);
561-
for (auto [var, arg] : blockArgsPairs)
562-
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
563-
};
564-
565572
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
566573
.Case([&](omp::SimdOp op) {
567-
forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
574+
forwardArgs(moduleTranslation,
575+
cast<omp::BlockArgOpenMPOpInterface>(*op));
568576
op.emitWarning() << "simd information on composite construct discarded";
569577
return success();
570578
})
@@ -5803,6 +5811,61 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
58035811
return WalkResult::interrupt();
58045812
return WalkResult::skip();
58055813
}
5814+
5815+
// Non-target ops might nest target-related ops, therefore, we
5816+
// translate them as non-OpenMP scopes. Translating them is needed by
5817+
// nested target-related ops since they might need LLVM values defined
5818+
// in their parent non-target ops.
5819+
if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
5820+
oper->getParentOfType<LLVM::LLVMFuncOp>() &&
5821+
!oper->getRegions().empty()) {
5822+
if (auto blockArgsIface =
5823+
dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
5824+
forwardArgs(moduleTranslation, blockArgsIface);
5825+
else {
5826+
// Here we map entry block arguments of
5827+
// non-BlockArgOpenMPOpInterface ops if they can be encountered
5828+
// inside of a function and they define any of these arguments.
5829+
if (isa<mlir::omp::AtomicUpdateOp>(oper))
5830+
for (auto [operand, arg] :
5831+
llvm::zip_equal(oper->getOperands(),
5832+
oper->getRegion(0).getArguments())) {
5833+
moduleTranslation.mapValue(
5834+
arg, builder.CreateLoad(
5835+
moduleTranslation.convertType(arg.getType()),
5836+
moduleTranslation.lookupValue(operand)));
5837+
}
5838+
}
5839+
5840+
if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5841+
assert(builder.GetInsertBlock() &&
5842+
"No insert block is set for the builder");
5843+
for (auto iv : loopNest.getIVs()) {
5844+
// Map the induction variable to a poison value just to keep the IR valid.
5845+
moduleTranslation.mapValue(
5846+
iv, llvm::PoisonValue::get(
5847+
moduleTranslation.convertType(iv.getType())));
5848+
}
5849+
}
5850+
5851+
for (Region &region : oper->getRegions()) {
5852+
// Regions are fake in the sense that they are not a truthful
5853+
// translation of the OpenMP construct being converted (e.g. no
5854+
// OpenMP runtime calls will be generated). We just need this to
5855+
// prepare the kernel invocation args.
5856+
SmallVector<llvm::PHINode *> phis;
5857+
auto result = convertOmpOpRegions(
5858+
region, oper->getName().getStringRef().str() + ".fake.region",
5859+
builder, moduleTranslation, &phis);
5860+
if (failed(handleError(result, *oper)))
5861+
return WalkResult::interrupt();
5862+
5863+
builder.SetInsertPoint(result.get(), result.get()->end());
5864+
}
5865+
5866+
return WalkResult::skip();
5867+
}
5868+
58065869
return WalkResult::advance();
58075870
}).wasInterrupted();
58085871
return failure(interrupted);
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
4+
5+
omp.private {type = private} @i32_privatizer : i32
6+
7+
llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
8+
omp.parallel {
9+
%0 = llvm.mlir.constant(4 : index) : i64
10+
%1 = llvm.mlir.constant(1 : index) : i64
11+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
12+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
13+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
14+
omp.terminator
15+
}
16+
omp.terminator
17+
}
18+
llvm.return
19+
}
20+
21+
// CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
22+
// CHECK-NEXT: br label %omp.parallel.fake.region
23+
// CHECK: omp.parallel.fake.region:
24+
// CHECK-NEXT: br label %omp.region.cont
25+
// CHECK: omp.region.cont:
26+
// CHECK-NEXT: ret void
27+
// CHECK-NEXT: }
28+
29+
llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
30+
%8 = llvm.mlir.constant(1 : i64) : i64
31+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
32+
%16 = llvm.mlir.constant(10 : i32) : i32
33+
%17 = llvm.mlir.constant(1 : i32) : i32
34+
omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
35+
omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
36+
llvm.store %arg1, %loop_arg : i32, !llvm.ptr
37+
%0 = llvm.mlir.constant(4 : index) : i64
38+
%1 = llvm.mlir.constant(1 : index) : i64
39+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
40+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
41+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
42+
omp.terminator
43+
}
44+
omp.yield
45+
}
46+
}
47+
llvm.return
48+
}
49+
50+
// CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
51+
// CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
52+
// CHECK-NEXT: br label %omp.wsloop.fake.region
53+
// CHECK: omp.wsloop.fake.region:
54+
// CHECK-NEXT: br label %omp.loop_nest.fake.region
55+
// CHECK: omp.loop_nest.fake.region:
56+
// CHECK-NEXT: store i32 poison, ptr %{{.*}}
57+
// CHECK-NEXT: br label %omp.region.cont1
58+
// CHECK: omp.region.cont1:
59+
// CHECK-NEXT: br label %omp.region.cont
60+
// CHECK: omp.region.cont:
61+
// CHECK-NEXT: ret void
62+
// CHECK-NEXT: }
63+
64+
llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
65+
%8 = llvm.mlir.constant(1 : i64) : i64
66+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
67+
omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
68+
%1 = llvm.mlir.constant(1 : index) : i64
69+
// Use the private clause from omp.parallel to make sure block arguments
70+
// are handled.
71+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
72+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
73+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
74+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
75+
omp.terminator
76+
}
77+
omp.terminator
78+
}
79+
llvm.return
80+
}
81+
82+
llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
83+
%8 = llvm.mlir.constant(1 : i64) : i64
84+
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
85+
omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
86+
%1 = llvm.mlir.constant(1 : index) : i64
87+
// Use the private clause from omp.task to make sure block arguments
88+
// are handled.
89+
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
90+
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
91+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
92+
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
93+
omp.terminator
94+
}
95+
omp.terminator
96+
}
97+
llvm.return
98+
}
99+
100+
llvm.func @test_target_and_atomic_update(%x: !llvm.ptr, %expr : i32) {
101+
omp.target {
102+
omp.terminator
103+
}
104+
105+
omp.atomic.update %x : !llvm.ptr {
106+
^bb0(%xval: i32):
107+
%newval = llvm.add %xval, %expr : i32
108+
omp.yield(%newval : i32)
109+
}
110+
111+
llvm.return
112+
}
113+
114+
// CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
115+
// CHECK: br label %omp.parallel.fake.region
116+
// CHECK: omp.parallel.fake.region:
117+
// CHECK: br label %omp.region.cont
118+
// CHECK: omp.region.cont:
119+
// CHECK-NEXT: ret void
120+
// CHECK-NEXT: }
121+
122+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
123+
// CHECK: call i32 @__kmpc_target_init
124+
// CHECK: user_code.entry:
125+
// CHECK: call void @__kmpc_target_deinit()
126+
// CHECK: ret void
127+
// CHECK: }
128+
129+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
130+
// CHECK: call i32 @__kmpc_target_init
131+
// CHECK: user_code.entry:
132+
// CHECK: call void @__kmpc_target_deinit()
133+
// CHECK: ret void
134+
// CHECK: }
135+
136+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
137+
// CHECK: call i32 @__kmpc_target_init
138+
// CHECK: user_code.entry:
139+
// CHECK: call void @__kmpc_target_deinit()
140+
// CHECK: ret void
141+
// CHECK: }
142+
143+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
144+
// CHECK: call i32 @__kmpc_target_init
145+
// CHECK: user_code.entry:
146+
// CHECK: call void @__kmpc_target_deinit()
147+
// CHECK: ret void
148+
// CHECK: }
149+
150+
// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_target_and_atomic_update_{{.*}} {
151+
// CHECK: call i32 @__kmpc_target_init
152+
// CHECK: user_code.entry:
153+
// CHECK: call void @__kmpc_target_deinit()
154+
// CHECK: ret void
155+
// CHECK: }
156+
}

0 commit comments

Comments
 (0)