Skip to content

Commit 1424297

Browse files
SouraVX authored and schweitzpgi committed
[OpenMP][flang]Lower NUM_THREADS clause for parallel construct
1 parent 90b3f5b commit 1424297

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "flang/Lower/FIRBuilder.h"
1616
#include "flang/Lower/PFTBuilder.h"
1717
#include "flang/Parser/parse-tree.h"
18+
#include "flang/Semantics/tools.h"
1819
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1920
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2021

@@ -87,19 +88,34 @@ genOMP(Fortran::lower::AbstractConverter &absConv,
8788
auto &firOpBuilder = absConv.getFirOpBuilder();
8889
auto currentLocation = absConv.getCurrentLocation();
8990
auto insertPt = firOpBuilder.saveInsertionPoint();
91+
92+
// Clauses.
93+
// FIXME: Add support for other clauses.
94+
mlir::Value numThreads;
95+
96+
const auto &parallelOpClauseList =
97+
std::get<Fortran::parser::OmpClauseList>(blockDirective.t);
98+
for (const auto &clause : parallelOpClauseList.v) {
99+
if (const auto &numThreadsClause =
100+
std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u)) {
101+
// OMPIRBuilder expects the `NUM_THREADS` clause as a `Value`.
102+
numThreads = absConv.genExprValue(
103+
*Fortran::semantics::GetExpr(numThreadsClause->v));
104+
}
105+
}
90106
llvm::ArrayRef<mlir::Type> argTy;
91-
mlir::ValueRange range;
92-
llvm::SmallVector<int32_t, 6> operandSegmentSizes(6 /*Size=*/,
93-
0 /*Value=*/);
94-
// create and insert the operation.
107+
Attribute defaultValue, procBindValue;
108+
// Create and insert the operation.
109+
// Create the Op with empty ranges for clauses that are yet to be lowered.
95110
auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
96-
currentLocation, argTy, range);
97-
parallelOp.setAttr(mlir::omp::ParallelOp::getOperandSegmentSizeAttr(),
98-
firOpBuilder.getI32VectorAttr(operandSegmentSizes));
99-
parallelOp.getRegion().push_back(new Block{});
111+
currentLocation, argTy, Value(), numThreads,
112+
defaultValue.dyn_cast_or_null<StringAttr>(), ValueRange(), ValueRange(),
113+
ValueRange(), ValueRange(),
114+
procBindValue.dyn_cast_or_null<StringAttr>());
115+
firOpBuilder.createBlock(&parallelOp.getRegion());
100116
auto &block = parallelOp.getRegion().back();
101117
firOpBuilder.setInsertionPointToStart(&block);
102-
// ensure the block is well-formed.
118+
// Ensure the block is well-formed.
103119
firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
104120
firOpBuilder.restoreInsertionPoint(insertPt);
105121
}

flang/test/Lower/OpenMP/empty-omp-parallel.f90

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
program parallel
1111

1212
integer :: a,b,c
13+
integer :: num_threads
1314
! This and last statements are just for the sake of ensuring that the
1415
! operation is created/inserted correctly and does not break/interfere with
1516
! other pieces which may be present before/after the operation.
1617
! However this test does not verify operation corresponding to this
1718
! statement.
1819
c = a + b
1920
!$OMP PARALLEL
21+
!$OMP END PARALLEL
2022
!FIRDialect: omp.parallel {
2123
!FIRDialect-NEXT: omp.terminator
2224
!FIRDialect-NEXT: }
@@ -25,9 +27,35 @@ program parallel
2527
!LLVMIRDialect-NEXT: omp.terminator
2628
!LLVMIRDialect-NEXT: }
2729

30+
!$OMP PARALLEL NUM_THREADS(16)
31+
!$OMP END PARALLEL
32+
num_threads = 4
33+
!$OMP PARALLEL NUM_THREADS(num_threads)
34+
!$OMP END PARALLEL
35+
36+
!FIRDialect: omp.parallel num_threads(%{{.*}} : i32) {
37+
!FIRDialect-NEXT: omp.terminator
38+
!FIRDialect-NEXT: }
39+
40+
!LLVMIRDialect: omp.parallel num_threads(%{{.*}} : !llvm.i32) {
41+
!LLVMIRDialect-NEXT: omp.terminator
42+
!LLVMIRDialect-NEXT: }
43+
44+
45+
!LLVMIR-LABEL: call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{.*}})
2846
!LLVMIR: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN:.*]] to {{.*}}
47+
48+
!LLVMIR: %[[GLOBAL_THREAD_NUM1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{.*}})
49+
!LLVMIR: call void @__kmpc_push_num_threads(%struct.ident_t* @{{.*}}, i32 %[[GLOBAL_THREAD_NUM1]], i32 16)
50+
!LLVMIR: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN1:.*]] to {{.*}}
51+
52+
!LLVMIR: %[[GLOBAL_THREAD_NUM2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{.*}})
53+
!LLVMIR: call void @__kmpc_push_num_threads(%struct.ident_t* @{{.*}}, i32 %[[GLOBAL_THREAD_NUM2]], i32 %{{.*}})
54+
!LLVMIR: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN2:.*]] to {{.*}}
55+
56+
!LLVMIR: define internal void @[[OMP_OUTLINED_FN2]]
57+
!LLVMIR: define internal void @[[OMP_OUTLINED_FN1]]
2958
!LLVMIR: define internal void @[[OMP_OUTLINED_FN]]
30-
!$OMP END PARALLEL
3159
b = a + c
3260

3361
end program

0 commit comments

Comments
 (0)