Skip to content

Commit a7119a1

Browse files
authored
[OpenMP][mlir] Add translation for if in omp.teams (#69404)
This patch adds translation for `if` clause on `teams` construct in OpenMP Dialect.
1 parent 28ae42e commit a7119a1

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
666666
LLVM::ModuleTranslation &moduleTranslation) {
667667
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
668668
LogicalResult bodyGenStatus = success();
669-
if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
669+
if (!op.getAllocatorsVars().empty() || op.getReductions())
670670
return op.emitError("unhandled clauses for translation to LLVM IR");
671671

672672
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
@@ -689,9 +689,13 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
689689
if (Value threadLimitVar = op.getThreadLimit())
690690
threadLimit = moduleTranslation.lookupValue(threadLimitVar);
691691

692+
llvm::Value *ifExpr = nullptr;
693+
if (Value ifExprVar = op.getIfExpr())
694+
ifExpr = moduleTranslation.lookupValue(ifExprVar);
695+
692696
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
693697
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
694-
ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
698+
ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
695699
return bodyGenStatus;
696700
}
697701

mlir/test/Target/LLVMIR/openmp-teams.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,54 @@ llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUp
235235
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
236236
// CHECK: call void @duringTeams()
237237
// CHECK: ret void
238+
239+
// -----
240+
241+
llvm.func @beforeTeams()
242+
llvm.func @duringTeams()
243+
llvm.func @afterTeams()
244+
245+
// CHECK-LABEL: @teams_if
246+
// CHECK-SAME: (i1 [[ARG:.+]])
247+
llvm.func @teams_if(%arg : i1) {
248+
// CHECK-NEXT: call void @beforeTeams()
249+
llvm.call @beforeTeams() : () -> ()
250+
// If the condition is true, then the value of bounds is zero - which basically means "implementation-defined".
251+
// The runtime sees zero and sets a default value of number of teams. This behavior is according to the standard.
252+
// The same is true for `thread_limit`.
253+
// CHECK: [[NUM_TEAMS_UPPER:%.+]] = select i1 [[ARG]], i32 0, i32 1
254+
// CHECK: [[NUM_TEAMS_LOWER:%.+]] = select i1 [[ARG]], i32 0, i32 1
255+
// CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
256+
// CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}})
257+
omp.teams if(%arg) {
258+
llvm.call @duringTeams() : () -> ()
259+
omp.terminator
260+
}
261+
// CHECK: call void @afterTeams()
262+
llvm.call @afterTeams() : () -> ()
263+
llvm.return
264+
}
265+
266+
// -----
267+
268+
llvm.func @beforeTeams()
269+
llvm.func @duringTeams()
270+
llvm.func @afterTeams()
271+
272+
// CHECK-LABEL: @teams_if_with_num_teams
273+
// CHECK-SAME: (i1 [[CONDITION:.+]], i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
274+
llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
275+
// CHECK: call void @beforeTeams()
276+
llvm.call @beforeTeams() : () -> ()
277+
// CHECK: [[NUM_TEAMS_UPPER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_UPPER]], i32 1
278+
// CHECK: [[NUM_TEAMS_LOWER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_LOWER]], i32 1
279+
// CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER_NEW]], i32 [[NUM_TEAMS_UPPER_NEW]], i32 [[THREAD_LIMIT]])
280+
// CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}})
281+
omp.teams if(%condition) num_teams(%numTeamsLower: i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
282+
llvm.call @duringTeams() : () -> ()
283+
omp.terminator
284+
}
285+
// CHECK: call void @afterTeams()
286+
llvm.call @afterTeams() : () -> ()
287+
llvm.return
288+
}

0 commit comments

Comments
 (0)