Skip to content

Commit 3b390a1

Browse files
[flang][OpenMP] Support for Collapse
Convert Fortran parse-tree into MLIR for collapse-clause. Includes simple Fortran to LLVM-IR test, with auto-generated check-lines (some of which have been edited by hand). Reviewed By: kiranchandramohan, shraiysh, peixin Differential Revision: https://reviews.llvm.org/D125302
1 parent 729467a commit 3b390a1

File tree

4 files changed

+162
-41
lines changed

4 files changed

+162
-41
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
#ifndef FORTRAN_LOWER_OPENMP_H
1414
#define FORTRAN_LOWER_OPENMP_H
1515

16+
#include <cinttypes>
17+
1618
namespace Fortran {
1719
namespace parser {
1820
struct OpenMPConstruct;
1921
struct OpenMPDeclarativeConstruct;
22+
struct OmpEndLoopDirective;
23+
struct OmpClauseList;
2024
} // namespace parser
2125

2226
namespace lower {
@@ -31,6 +35,7 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
3135
const parser::OpenMPConstruct &);
3236
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
3337
const parser::OpenMPDeclarativeConstruct &);
38+
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
3439

3540
} // namespace lower
3641
} // namespace Fortran

flang/lib/Lower/Bridge.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,15 +1401,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
14011401
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
14021402
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
14031403
localSymbols.pushScope();
1404-
Fortran::lower::genOpenMPConstruct(*this, getEval(), omp);
1404+
genOpenMPConstruct(*this, getEval(), omp);
1405+
1406+
const Fortran::parser::OpenMPLoopConstruct *ompLoop =
1407+
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
14051408

14061409
// If loop is part of an OpenMP Construct then the OpenMP dialect
14071410
// workshare loop operation has already been created. Only the
14081411
// body needs to be created here and the do_loop can be skipped.
1409-
Fortran::lower::pft::Evaluation *curEval =
1410-
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)
1411-
? &getEval().getFirstNestedEvaluation()
1412-
: &getEval();
1412+
// Skip the number of collapsed loops, which is 1 when there is a
1413+
// no collapse requested.
1414+
1415+
Fortran::lower::pft::Evaluation *curEval = &getEval();
1416+
if (ompLoop) {
1417+
const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
1418+
std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
1419+
int64_t collapseValue =
1420+
Fortran::lower::getCollapseValue(wsLoopOpClauseList);
1421+
1422+
curEval = &curEval->getFirstNestedEvaluation();
1423+
for (int64_t i = 1; i < collapseValue; i++) {
1424+
curEval = &*std::next(curEval->getNestedEvaluations().begin());
1425+
}
1426+
}
14131427

14141428
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
14151429
genFIR(e);

flang/lib/Lower/OpenMP.cpp

Lines changed: 81 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,18 @@
2525

2626
using namespace mlir;
2727

28+
int64_t Fortran::lower::getCollapseValue(
29+
const Fortran::parser::OmpClauseList &clauseList) {
30+
for (const auto &clause : clauseList.v) {
31+
if (const auto &collapseClause =
32+
std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
33+
const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
34+
return Fortran::evaluate::ToInt64(*expr).value();
35+
}
36+
}
37+
return 1;
38+
}
39+
2840
static const Fortran::parser::Name *
2941
getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
3042
const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
@@ -108,22 +120,42 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
108120
}
109121
}
110122

123+
/// Create the body (block) for an OpenMP Operation.
124+
///
125+
/// \param [in] op - the operation the body belongs to.
126+
/// \param [inout] converter - converter to use for the clauses.
127+
/// \param [in] loc - location in source code.
128+
/// \oaran [in] clauses - list of clauses to process.
129+
/// \param [in] args - block arguments (induction variable[s]) for the
130+
//// region.
131+
/// \param [in] outerCombined - is this an outer operation - prevents
132+
/// privatization.
111133
template <typename Op>
112134
static void
113135
createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
114136
mlir::Location &loc,
115137
const Fortran::parser::OmpClauseList *clauses = nullptr,
116-
const Fortran::semantics::Symbol *arg = nullptr,
138+
const SmallVector<const Fortran::semantics::Symbol *> &args = {},
117139
bool outerCombined = false) {
118140
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
119-
// If an argument for the region is provided then create the block with that
120-
// argument. Also update the symbol's address with the mlir argument value.
121-
// e.g. For loops the argument is the induction variable. And all further
141+
// If arguments for the region are provided then create the block with those
142+
// arguments. Also update the symbol's address with the mlir argument values.
143+
// e.g. For loops the arguments are the induction variable. And all further
122144
// uses of the induction variable should use this mlir value.
123-
if (arg) {
124-
firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)},
125-
{loc});
126-
converter.bindSymbol(*arg, op.getRegion().front().getArgument(0));
145+
if (args.size()) {
146+
SmallVector<Type> tiv;
147+
SmallVector<Location> locs;
148+
int argIndex = 0;
149+
for (auto &arg : args) {
150+
tiv.push_back(converter.genType(*arg));
151+
locs.push_back(loc);
152+
}
153+
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
154+
for (auto &arg : args) {
155+
fir::ExtendedValue exval = op.getRegion().front().getArgument(argIndex);
156+
converter.bindSymbol(*arg, exval);
157+
argIndex++;
158+
}
127159
} else {
128160
firOpBuilder.createBlock(&op.getRegion());
129161
}
@@ -394,38 +426,44 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
394426
TODO(converter.getCurrentLocation(), "Combined worksharing loop construct");
395427
}
396428

397-
Fortran::lower::pft::Evaluation *doConstructEval =
398-
&eval.getFirstNestedEvaluation();
399-
400-
Fortran::lower::pft::Evaluation *doLoop =
401-
&doConstructEval->getFirstNestedEvaluation();
402-
auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
403-
assert(doStmt && "Expected do loop to be in the nested evaluation");
404-
const auto &loopControl =
405-
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
406-
const Fortran::parser::LoopControl::Bounds *bounds =
407-
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
408-
assert(bounds && "Expected bounds for worksharing do loop");
409-
Fortran::semantics::Symbol *iv = nullptr;
410-
Fortran::lower::StatementContext stmtCtx;
411-
lowerBound.push_back(fir::getBase(converter.genExprValue(
412-
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
413-
upperBound.push_back(fir::getBase(converter.genExprValue(
414-
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
415-
if (bounds->step) {
416-
step.push_back(fir::getBase(converter.genExprValue(
417-
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
418-
} else { // If `step` is not present, assume it as `1`.
419-
step.push_back(firOpBuilder.createIntegerConstant(
420-
currentLocation, firOpBuilder.getIntegerType(32), 1));
421-
}
422-
iv = bounds->name.thing.symbol;
429+
int64_t collapseValue = Fortran::lower::getCollapseValue(wsLoopOpClauseList);
430+
431+
// Collect the loops to collapse.
432+
auto *doConstructEval = &eval.getFirstNestedEvaluation();
433+
434+
SmallVector<const Fortran::semantics::Symbol *> iv;
435+
do {
436+
auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
437+
auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
438+
assert(doStmt && "Expected do loop to be in the nested evaluation");
439+
const auto &loopControl =
440+
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
441+
const Fortran::parser::LoopControl::Bounds *bounds =
442+
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
443+
assert(bounds && "Expected bounds for worksharing do loop");
444+
Fortran::lower::StatementContext stmtCtx;
445+
lowerBound.push_back(fir::getBase(converter.genExprValue(
446+
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
447+
upperBound.push_back(fir::getBase(converter.genExprValue(
448+
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
449+
if (bounds->step) {
450+
step.push_back(fir::getBase(converter.genExprValue(
451+
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
452+
} else { // If `step` is not present, assume it as `1`.
453+
step.push_back(firOpBuilder.createIntegerConstant(
454+
currentLocation, firOpBuilder.getIntegerType(32), 1));
455+
}
456+
iv.push_back(bounds->name.thing.symbol);
457+
458+
collapseValue--;
459+
doConstructEval =
460+
&*std::next(doConstructEval->getNestedEvaluations().begin());
461+
} while (collapseValue > 0);
423462

424463
// FIXME: Add support for following clauses:
425464
// 1. linear
426465
// 2. order
427-
// 3. collapse
428-
// 4. schedule (with chunk)
466+
// 3. schedule (with chunk)
429467
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
430468
currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
431469
reductionVars, /*reductions=*/nullptr,
@@ -451,6 +489,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
451489
} else {
452490
wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0));
453491
}
492+
} else if (const auto &collapseClause =
493+
std::get_if<Fortran::parser::OmpClause::Collapse>(
494+
&clause.u)) {
495+
const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
496+
const std::optional<std::int64_t> collapseValue =
497+
Fortran::evaluate::ToInt64(*expr);
498+
wsLoopOp.collapse_valAttr(firOpBuilder.getI64IntegerAttr(*collapseValue));
454499
} else if (const auto &scheduleClause =
455500
std::get_if<Fortran::parser::OmpClause::Schedule>(
456501
&clause.u)) {
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
! This test checks lowering of OpenMP DO Directive(Worksharing) with collapse.
2+
3+
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
4+
5+
program wsloop_collapse
6+
integer :: i, j, k
7+
integer :: a, b, c
8+
integer :: x
9+
! CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
10+
! CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
11+
! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFEc"}
12+
! CHECK: %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
13+
! CHECK: %[[VAL_4:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFEj"}
14+
! CHECK: %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFEk"}
15+
! CHECK: %[[VAL_6:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
16+
a=3
17+
! CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32
18+
! CHECK: fir.store %[[VAL_7]] to %[[VAL_0]] : !fir.ref<i32>
19+
b=2
20+
! CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32
21+
! CHECK: fir.store %[[VAL_8]] to %[[VAL_1]] : !fir.ref<i32>
22+
c=5
23+
! CHECK: %[[VAL_9:.*]] = arith.constant 5 : i32
24+
! CHECK: fir.store %[[VAL_9]] to %[[VAL_2]] : !fir.ref<i32>
25+
x=0
26+
! CHECK: %[[VAL_10:.*]] = arith.constant 0 : i32
27+
! CHECK: fir.store %[[VAL_10]] to %[[VAL_6]] : !fir.ref<i32>
28+
29+
!$omp do collapse(3)
30+
! CHECK: %[[VAL_20:.*]] = arith.constant 1 : i32
31+
! CHECK: %[[VAL_21:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
32+
! CHECK: %[[VAL_22:.*]] = arith.constant 1 : i32
33+
! CHECK: %[[VAL_23:.*]] = arith.constant 1 : i32
34+
! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
35+
! CHECK: %[[VAL_25:.*]] = arith.constant 1 : i32
36+
! CHECK: %[[VAL_26:.*]] = arith.constant 1 : i32
37+
! CHECK: %[[VAL_27:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
38+
! CHECK: %[[VAL_28:.*]] = arith.constant 1 : i32
39+
do i = 1, a
40+
do j= 1, b
41+
do k = 1, c
42+
! CHECK: omp.wsloop collapse(3) for (%[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]]) : i32 = (%[[VAL_20]], %[[VAL_23]], %[[VAL_26]]) to (%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]) inclusive step (%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]) {
43+
! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_6]] : !fir.ref<i32>
44+
! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : i32
45+
! CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
46+
! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : i32
47+
! CHECK: fir.store %[[VAL_15]] to %[[VAL_6]] : !fir.ref<i32>
48+
! CHECK: omp.yield
49+
! CHECK: }
50+
x = x + i + j + k
51+
end do
52+
end do
53+
end do
54+
!$omp end do
55+
! CHECK: return
56+
! CHECK: }
57+
end program wsloop_collapse

0 commit comments

Comments
 (0)