25
25
26
26
using namespace mlir ;
27
27
28
+ int64_t Fortran::lower::getCollapseValue (
29
+ const Fortran::parser::OmpClauseList &clauseList) {
30
+ for (const auto &clause : clauseList.v ) {
31
+ if (const auto &collapseClause =
32
+ std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u )) {
33
+ const auto *expr = Fortran::semantics::GetExpr (collapseClause->v );
34
+ return Fortran::evaluate::ToInt64 (*expr).value ();
35
+ }
36
+ }
37
+ return 1 ;
38
+ }
39
+
28
40
static const Fortran::parser::Name *
29
41
getDesignatorNameIfDataRef (const Fortran::parser::Designator &designator) {
30
42
const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u );
@@ -108,22 +120,42 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
108
120
}
109
121
}
110
122
123
+ // / Create the body (block) for an OpenMP Operation.
124
+ // /
125
+ // / \param [in] op - the operation the body belongs to.
126
+ // / \param [inout] converter - converter to use for the clauses.
127
+ // / \param [in] loc - location in source code.
128
+ // / \oaran [in] clauses - list of clauses to process.
129
+ // / \param [in] args - block arguments (induction variable[s]) for the
130
+ // // region.
131
+ // / \param [in] outerCombined - is this an outer operation - prevents
132
+ // / privatization.
111
133
template <typename Op>
112
134
static void
113
135
createBodyOfOp (Op &op, Fortran::lower::AbstractConverter &converter,
114
136
mlir::Location &loc,
115
137
const Fortran::parser::OmpClauseList *clauses = nullptr ,
116
- const Fortran::semantics::Symbol *arg = nullptr ,
138
+ const SmallVector< const Fortran::semantics::Symbol *> &args = {} ,
117
139
bool outerCombined = false ) {
118
140
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
119
- // If an argument for the region is provided then create the block with that
120
- // argument . Also update the symbol's address with the mlir argument value .
121
- // e.g. For loops the argument is the induction variable. And all further
141
+ // If arguments for the region are provided then create the block with those
142
+ // arguments . Also update the symbol's address with the mlir argument values .
143
+ // e.g. For loops the arguments are the induction variable. And all further
122
144
// uses of the induction variable should use this mlir value.
123
- if (arg) {
124
- firOpBuilder.createBlock (&op.getRegion (), {}, {converter.genType (*arg)},
125
- {loc});
126
- converter.bindSymbol (*arg, op.getRegion ().front ().getArgument (0 ));
145
+ if (args.size ()) {
146
+ SmallVector<Type> tiv;
147
+ SmallVector<Location> locs;
148
+ int argIndex = 0 ;
149
+ for (auto &arg : args) {
150
+ tiv.push_back (converter.genType (*arg));
151
+ locs.push_back (loc);
152
+ }
153
+ firOpBuilder.createBlock (&op.getRegion (), {}, tiv, locs);
154
+ for (auto &arg : args) {
155
+ fir::ExtendedValue exval = op.getRegion ().front ().getArgument (argIndex);
156
+ converter.bindSymbol (*arg, exval);
157
+ argIndex++;
158
+ }
127
159
} else {
128
160
firOpBuilder.createBlock (&op.getRegion ());
129
161
}
@@ -394,38 +426,44 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
394
426
TODO (converter.getCurrentLocation (), " Combined worksharing loop construct" );
395
427
}
396
428
397
- Fortran::lower::pft::Evaluation *doConstructEval =
398
- &eval.getFirstNestedEvaluation ();
399
-
400
- Fortran::lower::pft::Evaluation *doLoop =
401
- &doConstructEval->getFirstNestedEvaluation ();
402
- auto *doStmt = doLoop->getIf <Fortran::parser::NonLabelDoStmt>();
403
- assert (doStmt && " Expected do loop to be in the nested evaluation" );
404
- const auto &loopControl =
405
- std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
406
- const Fortran::parser::LoopControl::Bounds *bounds =
407
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
408
- assert (bounds && " Expected bounds for worksharing do loop" );
409
- Fortran::semantics::Symbol *iv = nullptr ;
410
- Fortran::lower::StatementContext stmtCtx;
411
- lowerBound.push_back (fir::getBase (converter.genExprValue (
412
- *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
413
- upperBound.push_back (fir::getBase (converter.genExprValue (
414
- *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
415
- if (bounds->step ) {
416
- step.push_back (fir::getBase (converter.genExprValue (
417
- *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
418
- } else { // If `step` is not present, assume it as `1`.
419
- step.push_back (firOpBuilder.createIntegerConstant (
420
- currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
421
- }
422
- iv = bounds->name .thing .symbol ;
429
+ int64_t collapseValue = Fortran::lower::getCollapseValue (wsLoopOpClauseList);
430
+
431
+ // Collect the loops to collapse.
432
+ auto *doConstructEval = &eval.getFirstNestedEvaluation ();
433
+
434
+ SmallVector<const Fortran::semantics::Symbol *> iv;
435
+ do {
436
+ auto *doLoop = &doConstructEval->getFirstNestedEvaluation ();
437
+ auto *doStmt = doLoop->getIf <Fortran::parser::NonLabelDoStmt>();
438
+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
439
+ const auto &loopControl =
440
+ std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
441
+ const Fortran::parser::LoopControl::Bounds *bounds =
442
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
443
+ assert (bounds && " Expected bounds for worksharing do loop" );
444
+ Fortran::lower::StatementContext stmtCtx;
445
+ lowerBound.push_back (fir::getBase (converter.genExprValue (
446
+ *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
447
+ upperBound.push_back (fir::getBase (converter.genExprValue (
448
+ *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
449
+ if (bounds->step ) {
450
+ step.push_back (fir::getBase (converter.genExprValue (
451
+ *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
452
+ } else { // If `step` is not present, assume it as `1`.
453
+ step.push_back (firOpBuilder.createIntegerConstant (
454
+ currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
455
+ }
456
+ iv.push_back (bounds->name .thing .symbol );
457
+
458
+ collapseValue--;
459
+ doConstructEval =
460
+ &*std::next (doConstructEval->getNestedEvaluations ().begin ());
461
+ } while (collapseValue > 0 );
423
462
424
463
// FIXME: Add support for following clauses:
425
464
// 1. linear
426
465
// 2. order
427
- // 3. collapse
428
- // 4. schedule (with chunk)
466
+ // 3. schedule (with chunk)
429
467
auto wsLoopOp = firOpBuilder.create <mlir::omp::WsLoopOp>(
430
468
currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
431
469
reductionVars, /* reductions=*/ nullptr ,
@@ -451,6 +489,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
451
489
} else {
452
490
wsLoopOp.ordered_valAttr (firOpBuilder.getI64IntegerAttr (0 ));
453
491
}
492
+ } else if (const auto &collapseClause =
493
+ std::get_if<Fortran::parser::OmpClause::Collapse>(
494
+ &clause.u )) {
495
+ const auto *expr = Fortran::semantics::GetExpr (collapseClause->v );
496
+ const std::optional<std::int64_t > collapseValue =
497
+ Fortran::evaluate::ToInt64 (*expr);
498
+ wsLoopOp.collapse_valAttr (firOpBuilder.getI64IntegerAttr (*collapseValue));
454
499
} else if (const auto &scheduleClause =
455
500
std::get_if<Fortran::parser::OmpClause::Schedule>(
456
501
&clause.u )) {
0 commit comments