Skip to content

Commit 33409d2

Browse files
committed
[MLIR][OpenMP] Support target SPMD
This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for `target teams distribute parallel do [simd]` and `target teams distribute [simd]`.
1 parent 5153e0d commit 33409d2

File tree

3 files changed

+159
-44
lines changed

3 files changed

+159
-44
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
176176
if (op.getHint())
177177
op.emitWarning("hint clause discarded");
178178
};
179-
auto checkHostEval = [](auto op, LogicalResult &result) {
180-
// Host evaluated clauses are supported, except for loop bounds.
181-
for (BlockArgument arg :
182-
cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
183-
for (Operation *user : arg.getUsers())
184-
if (isa<omp::LoopNestOp>(user))
185-
result = op.emitError("not yet implemented: host evaluation of loop "
186-
"bounds in omp.target operation");
187-
};
188179
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
189180
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
190181
op.getInReductionSyms())
@@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
321312
checkBare(op, result);
322313
checkDevice(op, result);
323314
checkHasDeviceAddr(op, result);
324-
checkHostEval(op, result);
325315
checkInReduction(op, result);
326316
checkIsDevicePtr(op, result);
327317
checkPrivate(op, result);
@@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
40544044
///
40554045
/// Loop bounds and steps are populated only if the corresponding output
40564046
/// vectors are provided.
4057-
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4058-
Value &numTeamsLower, Value &numTeamsUpper,
4059-
Value &threadLimit) {
4047+
static void
4048+
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4049+
Value &numTeamsLower, Value &numTeamsUpper,
4050+
Value &threadLimit,
4051+
llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
4052+
llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
4053+
llvm::SmallVectorImpl<Value> *steps = nullptr) {
40604054
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
40614055
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
40624056
blockArgIface.getHostEvalBlockArgs())) {
@@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
40814075
llvm_unreachable("unsupported host_eval use");
40824076
})
40834077
.Case([&](omp::LoopNestOp loopOp) {
4084-
// TODO: Extract bounds and step values. Currently, this cannot be
4085-
// reached because translation would have been stopped earlier as a
4086-
// result of `checkImplementationStatus` detecting and reporting
4087-
// this situation.
4088-
llvm_unreachable("unsupported host_eval use");
4078+
auto processBounds =
4079+
[&](OperandRange opBounds,
4080+
llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4081+
bool found = false;
4082+
for (auto [i, lb] : llvm::enumerate(opBounds)) {
4083+
if (lb == blockArg) {
4084+
found = true;
4085+
if (outBounds)
4086+
(*outBounds)[i] = hostEvalVar;
4087+
}
4088+
}
4089+
return found;
4090+
};
4091+
bool found =
4092+
processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
4093+
found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
4094+
found;
4095+
found = processBounds(loopOp.getLoopSteps(), steps) || found;
4096+
if (!found)
4097+
llvm_unreachable("unsupported host_eval use");
40894098
})
40904099
.Default([](Operation *) {
40914100
llvm_unreachable("unsupported host_eval use");
@@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
42224231
combinedMaxThreadsVal = maxThreadsVal;
42234232

42244233
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4234+
attrs.ExecFlags = targetOp.getKernelExecFlags();
42254235
attrs.MinTeams = minTeamsVal;
42264236
attrs.MaxTeams.front() = maxTeamsVal;
42274237
attrs.MinThreads = 1;
@@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42394249
LLVM::ModuleTranslation &moduleTranslation,
42404250
omp::TargetOp targetOp,
42414251
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4252+
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4253+
targetOp.getInnermostCapturedOmpOp());
4254+
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
4255+
42424256
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4257+
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
4258+
steps(numLoops);
42434259
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
4244-
teamsThreadLimit);
4260+
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
42454261

42464262
// TODO: Handle constant 'if' clauses.
42474263
if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42614277
if (numThreads)
42624278
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
42634279

4264-
// TODO: Populate attrs.LoopTripCount if it is target SPMD.
4280+
if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4281+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4282+
attrs.LoopTripCount = nullptr;
4283+
4284+
// To calculate the trip count, we multiply together the trip counts of
4285+
// every collapsed canonical loop. We don't need to create the loop nests
4286+
// here, since we're only interested in the trip count.
4287+
for (auto [loopLower, loopUpper, loopStep] :
4288+
llvm::zip_equal(lowerBounds, upperBounds, steps)) {
4289+
llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
4290+
llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
4291+
llvm::Value *step = moduleTranslation.lookupValue(loopStep);
4292+
4293+
llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4294+
llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
4295+
loc, lowerBound, upperBound, step, /*IsSigned=*/true,
4296+
loopOp.getLoopInclusive());
4297+
4298+
if (!attrs.LoopTripCount) {
4299+
attrs.LoopTripCount = tripCount;
4300+
continue;
4301+
}
4302+
4303+
// TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here.
4304+
attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
4305+
{}, /*HasNUW=*/true);
4306+
}
4307+
}
42654308
}
42664309

42674310
static LogicalResult
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: split-file %s %t
2+
// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
3+
// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
4+
5+
//--- host.mlir
6+
7+
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
8+
llvm.func @main(%x : i32) {
9+
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
10+
omp.teams {
11+
omp.parallel {
12+
omp.distribute {
13+
omp.wsloop {
14+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
15+
omp.yield
16+
}
17+
} {omp.composite}
18+
} {omp.composite}
19+
omp.terminator
20+
} {omp.composite}
21+
omp.terminator
22+
}
23+
omp.terminator
24+
}
25+
llvm.return
26+
}
27+
}
28+
29+
// HOST-LABEL: define void @main
30+
// HOST: %omp_loop.tripcount = {{.*}}
31+
// HOST-NEXT: br label %[[ENTRY:.*]]
32+
// HOST: [[ENTRY]]:
33+
// HOST-NEXT: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
34+
// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
35+
// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
36+
// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
37+
// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
38+
// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
39+
// HOST: [[OFFLOAD_FAILED]]:
40+
// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
41+
42+
// HOST: define internal void @[[TARGET_OUTLINE]]
43+
// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
44+
45+
// HOST: define internal void @[[TEAMS_OUTLINE]]
46+
// HOST: call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
47+
48+
// HOST: define internal void @[[PARALLEL_OUTLINE]]
49+
// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
50+
51+
// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
52+
// HOST: call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
53+
54+
//--- device.mlir
55+
56+
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
57+
llvm.func @main(%x : i32) {
58+
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
59+
omp.teams {
60+
omp.parallel {
61+
omp.distribute {
62+
omp.wsloop {
63+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
64+
omp.yield
65+
}
66+
} {omp.composite}
67+
} {omp.composite}
68+
omp.terminator
69+
} {omp.composite}
70+
omp.terminator
71+
}
72+
omp.terminator
73+
}
74+
llvm.return
75+
}
76+
}
77+
78+
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
79+
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
80+
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
81+
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
82+
// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
83+
84+
// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
85+
// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
86+
// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
87+
// DEVICE: call void @__kmpc_target_deinit()
88+
89+
// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
90+
// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
91+
92+
// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
93+
// DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
94+
95+
// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
96+
// DEVICE: call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
319319

320320
// -----
321321

322-
llvm.func @target_host_eval(%x : i32) {
323-
// expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
324-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
325-
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
326-
omp.teams {
327-
omp.parallel {
328-
omp.distribute {
329-
omp.wsloop {
330-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
331-
omp.yield
332-
}
333-
} {omp.composite}
334-
} {omp.composite}
335-
omp.terminator
336-
} {omp.composite}
337-
omp.terminator
338-
}
339-
omp.terminator
340-
}
341-
llvm.return
342-
}
343-
344-
// -----
345-
346322
omp.declare_reduction @add_f32 : f32
347323
init {
348324
^bb0(%arg: f32):

0 commit comments

Comments
 (0)