Skip to content

Commit f7d4f86

Browse files
authored
[mlir][OpenMP] Added translation for omp.teams to LLVM IR (#68042)
This patch adds translation from `omp.teams` operation to LLVM IR using OpenMPIRBuilder. The clauses are not handled in this patch.
1 parent 1cfaa86 commit f7d4f86

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,31 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
661661
return bodyGenStatus;
662662
}
663663

664+
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
665+
static LogicalResult
666+
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
667+
LLVM::ModuleTranslation &moduleTranslation) {
668+
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
669+
LogicalResult bodyGenStatus = success();
670+
if (op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() ||
671+
op.getThreadLimit() || !op.getAllocatorsVars().empty() ||
672+
op.getReductions()) {
673+
return op.emitError("unhandled clauses for translation to LLVM IR");
674+
}
675+
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
676+
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
677+
moduleTranslation, allocaIP);
678+
builder.restoreIP(codegenIP);
679+
convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
680+
moduleTranslation, bodyGenStatus);
681+
};
682+
683+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
684+
builder.restoreIP(
685+
moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
686+
return bodyGenStatus;
687+
}
688+
664689
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
665690
static LogicalResult
666691
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
@@ -2397,6 +2422,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
23972422
.Case([&](omp::SingleOp op) {
23982423
return convertOmpSingle(op, builder, moduleTranslation);
23992424
})
2425+
.Case([&](omp::TeamsOp op) {
2426+
return convertOmpTeams(op, builder, moduleTranslation);
2427+
})
24002428
.Case([&](omp::TaskOp op) {
24012429
return convertOmpTaskOp(op, builder, moduleTranslation);
24022430
})
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
llvm.func @foo()
4+
5+
// CHECK-LABEL: @omp_teams_simple
6+
// CHECK: call void {{.*}} @__kmpc_fork_teams(ptr @{{.+}}, i32 0, ptr [[WRAPPER_FN:.+]])
7+
// CHECK: ret void
8+
llvm.func @omp_teams_simple() {
9+
omp.teams {
10+
llvm.call @foo() : () -> ()
11+
omp.terminator
12+
}
13+
llvm.return
14+
}
15+
16+
// CHECK: define internal void @[[OUTLINED_FN:.+]]()
17+
// CHECK: call void @foo()
18+
// CHECK: ret void
19+
// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}})
20+
// CHECK: call void @[[OUTLINED_FN]]
21+
// CHECK: ret void
22+
23+
// -----
24+
25+
llvm.func @foo(i32) -> ()
26+
27+
// CHECK-LABEL: @omp_teams_shared_simple
28+
// CHECK-SAME: (i32 [[ARG0:%.+]])
29+
// CHECK: [[STRUCT_ARG:%.+]] = alloca { i32 }
30+
// CHECK: br
31+
// CHECK: [[GEP:%.+]] = getelementptr { i32 }, ptr [[STRUCT_ARG]], i32 0, i32 0
32+
// CHECK: store i32 [[ARG0]], ptr [[GEP]]
33+
// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[WRAPPER_FN:.+]], ptr [[STRUCT_ARG]])
34+
// CHECK: ret void
35+
llvm.func @omp_teams_shared_simple(%arg0: i32) {
36+
omp.teams {
37+
llvm.call @foo(%arg0) : (i32) -> ()
38+
omp.terminator
39+
}
40+
llvm.return
41+
}
42+
43+
// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr [[STRUCT_ARG:%.+]])
44+
// CHECK: [[GEP:%.+]] = getelementptr { i32 }, ptr [[STRUCT_ARG]], i32 0, i32 0
45+
// CHECK: [[LOAD_GEP:%.+]] = load i32, ptr [[GEP]]
46+
// CHECK: call void @foo(i32 [[LOAD_GEP]])
47+
// CHECK: ret void
48+
// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}}, ptr [[STRUCT_ARG:.+]])
49+
// CHECK: call void [[OUTLINED_FN]](ptr [[STRUCT_ARG]])
50+
// CHECK: ret void
51+
52+
// -----
53+
54+
llvm.func @my_alloca_fn() -> !llvm.ptr<i32>
55+
llvm.func @foo(i32, f32, !llvm.ptr<i32>, f128, !llvm.ptr<i32>, i32) -> ()
56+
llvm.func @bar()
57+
58+
// CHECK-LABEL: @omp_teams_branching_shared
59+
// CHECK-SAME: (i1 [[CONDITION:%.+]], i32 [[ARG0:%.+]], float [[ARG1:%.+]], ptr [[ARG2:%.+]], fp128 [[ARG3:%.+]])
60+
61+
// Checking that the allocation for struct argument happens in the alloca block.
62+
// CHECK: [[STRUCT_ARG:%.+]] = alloca { i1, i32, float, ptr, fp128, ptr, i32 }
63+
// CHECK: [[ALLOCATED:%.+]] = call ptr @my_alloca_fn()
64+
// CHECK: [[LOADED:%.+]] = load i32, ptr [[ALLOCATED]]
65+
// CHECK: br label
66+
67+
// Checking that the shared values are stored properly in the struct arg.
68+
// CHECK: [[CONDITION_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]]
69+
// CHECK: store i1 [[CONDITION]], ptr [[CONDITION_PTR]]
70+
// CHECK: [[ARG0_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 1
71+
// CHECK: store i32 [[ARG0]], ptr [[ARG0_PTR]]
72+
// CHECK: [[ARG1_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 2
73+
// CHECK: store float [[ARG1]], ptr [[ARG1_PTR]]
74+
// CHECK: [[ARG2_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 3
75+
// CHECK: store ptr [[ARG2]], ptr [[ARG2_PTR]]
76+
// CHECK: [[ARG3_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 4
77+
// CHECK: store fp128 [[ARG3]], ptr [[ARG3_PTR]]
78+
// CHECK: [[ALLOCATED_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 5
79+
// CHECK: store ptr [[ALLOCATED]], ptr [[ALLOCATED_PTR]]
80+
// CHECK: [[LOADED_PTR:%.+]] = getelementptr {{.+}}, ptr [[STRUCT_ARG]], i32 0, i32 6
81+
// CHECK: store i32 [[LOADED]], ptr [[LOADED_PTR]]
82+
83+
// Runtime call.
84+
// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[WRAPPER_FN:@.+]], ptr [[STRUCT_ARG]])
85+
// CHECK: br label
86+
// CHECK: call void @bar()
87+
// CHECK: ret void
88+
llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %arg2: !llvm.ptr<i32>, %arg3: f128) {
89+
%allocated = llvm.call @my_alloca_fn(): () -> !llvm.ptr<i32>
90+
%loaded = llvm.load %allocated : !llvm.ptr<i32>
91+
llvm.br ^codegenBlock
92+
^codegenBlock:
93+
omp.teams {
94+
llvm.cond_br %condition, ^true_block, ^false_block
95+
^true_block:
96+
llvm.call @foo(%arg0, %arg1, %arg2, %arg3, %allocated, %loaded) : (i32, f32, !llvm.ptr<i32>, f128, !llvm.ptr<i32>, i32) -> ()
97+
llvm.br ^exit
98+
^false_block:
99+
llvm.br ^exit
100+
^exit:
101+
omp.terminator
102+
}
103+
llvm.call @bar() : () -> ()
104+
llvm.return
105+
}
106+
107+
// Check the outlined function.
108+
// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr [[DATA:%.+]])
109+
// CHECK: [[CONDITION_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]]
110+
// CHECK: [[CONDITION:%.+]] = load i1, ptr [[CONDITION_PTR]]
111+
// CHECK: [[ARG0_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 1
112+
// CHECK: [[ARG0:%.+]] = load i32, ptr [[ARG0_PTR]]
113+
// CHECK: [[ARG1_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 2
114+
// CHECK: [[ARG1:%.+]] = load float, ptr [[ARG1_PTR]]
115+
// CHECK: [[ARG2_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 3
116+
// CHECK: [[ARG2:%.+]] = load ptr, ptr [[ARG2_PTR]]
117+
// CHECK: [[ARG3_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 4
118+
// CHECK: [[ARG3:%.+]] = load fp128, ptr [[ARG3_PTR]]
119+
// CHECK: [[ALLOCATED_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 5
120+
// CHECK: [[ALLOCATED:%.+]] = load ptr, ptr [[ALLOCATED_PTR]]
121+
// CHECK: [[LOADED_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 6
122+
// CHECK: [[LOADED:%.+]] = load i32, ptr [[LOADED_PTR]]
123+
// CHECK: br label
124+
125+
// CHECK: br i1 [[CONDITION]], label %[[TRUE:.+]], label %[[FALSE:.+]]
126+
// CHECK: [[FALSE]]:
127+
// CHECK-NEXT: br label
128+
// CHECK: [[TRUE]]:
129+
// CHECK: call void @foo(i32 [[ARG0]], float [[ARG1]], ptr [[ARG2]], fp128 [[ARG3]], ptr [[ALLOCATED]], i32 [[LOADED]])
130+
// CHECK-NEXT: br label
131+
// CHECK: ret void
132+
133+
// Check the wrapper function
134+
// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}}, ptr [[DATA:%.+]])
135+
// CHECK: call void [[OUTLINED_FN]](ptr [[DATA]])
136+
// CHECK: ret void

0 commit comments

Comments
 (0)