@@ -278,6 +278,80 @@ genOMP(Fortran::lower::AbstractConverter &converter,
278
278
standaloneConstruct.u );
279
279
}
280
280
281
+ static omp::ClauseProcBindKindAttr genProcBindKindAttr (
282
+ fir::FirOpBuilder &firOpBuilder,
283
+ const Fortran::parser::OmpClause::ProcBind *procBindClause) {
284
+ omp::ClauseProcBindKind pbKind;
285
+ switch (procBindClause->v .v ) {
286
+ case Fortran::parser::OmpProcBindClause::Type::Master:
287
+ pbKind = omp::ClauseProcBindKind::Master;
288
+ break ;
289
+ case Fortran::parser::OmpProcBindClause::Type::Close:
290
+ pbKind = omp::ClauseProcBindKind::Close;
291
+ break ;
292
+ case Fortran::parser::OmpProcBindClause::Type::Spread:
293
+ pbKind = omp::ClauseProcBindKind::Spread;
294
+ break ;
295
+ case Fortran::parser::OmpProcBindClause::Type::Primary:
296
+ pbKind = omp::ClauseProcBindKind::Primary;
297
+ break ;
298
+ }
299
+ return omp::ClauseProcBindKindAttr::get (firOpBuilder.getContext (), pbKind);
300
+ }
301
+
302
+ /* When parallel is used in a combined construct, then use this function to
303
+ * create the parallel operation. It handles the parallel specific clauses
304
+ * and leaves the rest for handling at the inner operations.
305
+ * TODO: Refactor clause handling
306
+ */
307
+ template <typename Directive>
308
+ static void
309
+ createCombinedParallelOp (Fortran::lower::AbstractConverter &converter,
310
+ Fortran::lower::pft::Evaluation &eval,
311
+ const Directive &directive) {
312
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
313
+ mlir::Location currentLocation = converter.getCurrentLocation ();
314
+ Fortran::lower::StatementContext stmtCtx;
315
+ llvm::ArrayRef<mlir::Type> argTy;
316
+ mlir::Value ifClauseOperand, numThreadsClauseOperand;
317
+ SmallVector<Value> allocatorOperands, allocateOperands;
318
+ mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
319
+ const auto &opClauseList =
320
+ std::get<Fortran::parser::OmpClauseList>(directive.t );
321
+ // TODO: Handle the following clauses
322
+ // 1. default
323
+ // 2. copyin
324
+ // Note: rest of the clauses are handled when the inner operation is created
325
+ for (const Fortran::parser::OmpClause &clause : opClauseList.v ) {
326
+ if (const auto &ifClause =
327
+ std::get_if<Fortran::parser::OmpClause::If>(&clause.u )) {
328
+ auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v .t );
329
+ mlir::Value ifVal = fir::getBase (
330
+ converter.genExprValue (*Fortran::semantics::GetExpr (expr), stmtCtx));
331
+ ifClauseOperand = firOpBuilder.createConvert (
332
+ currentLocation, firOpBuilder.getI1Type (), ifVal);
333
+ } else if (const auto &numThreadsClause =
334
+ std::get_if<Fortran::parser::OmpClause::NumThreads>(
335
+ &clause.u )) {
336
+ numThreadsClauseOperand = fir::getBase (converter.genExprValue (
337
+ *Fortran::semantics::GetExpr (numThreadsClause->v ), stmtCtx));
338
+ } else if (const auto &procBindClause =
339
+ std::get_if<Fortran::parser::OmpClause::ProcBind>(
340
+ &clause.u )) {
341
+ procBindKindAttr = genProcBindKindAttr (firOpBuilder, procBindClause);
342
+ }
343
+ }
344
+ // Create and insert the operation.
345
+ auto parallelOp = firOpBuilder.create <mlir::omp::ParallelOp>(
346
+ currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
347
+ allocateOperands, allocatorOperands, /* reduction_vars=*/ ValueRange (),
348
+ /* reductions=*/ nullptr , procBindKindAttr);
349
+
350
+ createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
351
+ &opClauseList, /* iv=*/ {},
352
+ /* isCombined=*/ true );
353
+ }
354
+
281
355
static void
282
356
genOMP (Fortran::lower::AbstractConverter &converter,
283
357
Fortran::lower::pft::Evaluation &eval,
@@ -318,23 +392,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
318
392
} else if (const auto &procBindClause =
319
393
std::get_if<Fortran::parser::OmpClause::ProcBind>(
320
394
&clause.u )) {
321
- omp::ClauseProcBindKind pbKind;
322
- switch (procBindClause->v .v ) {
323
- case Fortran::parser::OmpProcBindClause::Type::Master:
324
- pbKind = omp::ClauseProcBindKind::Master;
325
- break ;
326
- case Fortran::parser::OmpProcBindClause::Type::Close:
327
- pbKind = omp::ClauseProcBindKind::Close;
328
- break ;
329
- case Fortran::parser::OmpProcBindClause::Type::Spread:
330
- pbKind = omp::ClauseProcBindKind::Spread;
331
- break ;
332
- case Fortran::parser::OmpProcBindClause::Type::Primary:
333
- pbKind = omp::ClauseProcBindKind::Primary;
334
- break ;
335
- }
336
- procBindKindAttr =
337
- omp::ClauseProcBindKindAttr::get (firOpBuilder.getContext (), pbKind);
395
+ procBindKindAttr = genProcBindKindAttr (firOpBuilder, procBindClause);
338
396
} else if (const auto &allocateClause =
339
397
std::get_if<Fortran::parser::OmpClause::Allocate>(
340
398
&clause.u )) {
@@ -419,11 +477,17 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
419
477
noWaitClauseOperand, orderedClauseOperand, orderClauseOperand;
420
478
const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
421
479
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t ).t );
422
- if (llvm::omp::OMPD_do !=
480
+
481
+ const auto ompDirective =
423
482
std::get<Fortran::parser::OmpLoopDirective>(
424
483
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t ).t )
425
- .v ) {
426
- TODO (converter.getCurrentLocation (), " Combined worksharing loop construct" );
484
+ .v ;
485
+ if (llvm::omp::OMPD_parallel_do == ompDirective) {
486
+ createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>(
487
+ converter, eval,
488
+ std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t ));
489
+ } else if (llvm::omp::OMPD_do != ompDirective) {
490
+ TODO (converter.getCurrentLocation (), " Construct enclosing do loop" );
427
491
}
428
492
429
493
int64_t collapseValue = Fortran::lower::getCollapseValue (wsLoopOpClauseList);
@@ -648,15 +712,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
648
712
649
713
// Parallel Sections Construct
650
714
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
651
- auto parallelOp = firOpBuilder.create <mlir::omp::ParallelOp>(
652
- currentLocation, /* if_expr_var*/ nullptr , /* num_threads_var*/ nullptr ,
653
- allocateOperands, allocatorOperands, /* reduction_vars=*/ ValueRange (),
654
- /* reductions=*/ nullptr , /* proc_bind_val*/ nullptr );
655
- createBodyOfOp (parallelOp, converter, currentLocation);
715
+ createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>(
716
+ converter, eval,
717
+ std::get<Fortran::parser::OmpBeginSectionsDirective>(
718
+ sectionsConstruct.t ));
656
719
auto sectionsOp = firOpBuilder.create <mlir::omp::SectionsOp>(
657
720
currentLocation, /* reduction_vars*/ ValueRange (),
658
- /* reductions=*/ nullptr , /* allocate_vars */ ValueRange () ,
659
- /* allocators_vars */ ValueRange (), /* nowait=*/ nullptr );
721
+ /* reductions=*/ nullptr , allocateOperands, allocatorOperands ,
722
+ /* nowait=*/ nullptr );
660
723
createBodyOfOp (sectionsOp, converter, currentLocation);
661
724
662
725
// Sections Construct
0 commit comments