@@ -496,6 +496,8 @@ class ClauseProcessor {
496
496
bool processHint (mlir::IntegerAttr &result) const ;
497
497
bool processMergeable (mlir::UnitAttr &result) const ;
498
498
bool processNowait (mlir::UnitAttr &result) const ;
499
+ bool processNumTeams (Fortran::lower::StatementContext &stmtCtx,
500
+ mlir::Value &result) const ;
499
501
bool processNumThreads (Fortran::lower::StatementContext &stmtCtx,
500
502
mlir::Value &result) const ;
501
503
bool processOrdered (mlir::IntegerAttr &result) const ;
@@ -1347,6 +1349,18 @@ bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
1347
1349
return markClauseOccurrence<ClauseTy::Nowait>(result);
1348
1350
}
1349
1351
1352
+ bool ClauseProcessor::processNumTeams (Fortran::lower::StatementContext &stmtCtx,
1353
+ mlir::Value &result) const {
1354
+ // TODO Get lower and upper bounds for num_teams when parser is updated to
1355
+ // accept both.
1356
+ if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
1357
+ result = fir::getBase (converter.genExprValue (
1358
+ *Fortran::semantics::GetExpr (numTeamsClause->v ), stmtCtx));
1359
+ return true ;
1360
+ }
1361
+ return false ;
1362
+ }
1363
+
1350
1364
bool ClauseProcessor::processNumThreads (
1351
1365
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
1352
1366
if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
@@ -2359,6 +2373,40 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
2359
2373
mapOperands, mapTypesArrayAttr);
2360
2374
}
2361
2375
2376
+ static mlir::omp::TeamsOp
2377
+ genTeamsOp (Fortran::lower::AbstractConverter &converter,
2378
+ Fortran::lower::pft::Evaluation &eval,
2379
+ mlir::Location currentLocation,
2380
+ const Fortran::parser::OmpClauseList &clauseList,
2381
+ bool outerCombined = false ) {
2382
+ Fortran::lower::StatementContext stmtCtx;
2383
+ mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
2384
+ llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2385
+ reductionVars;
2386
+ llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2387
+
2388
+ ClauseProcessor cp (converter, clauseList);
2389
+ cp.processIf (stmtCtx,
2390
+ Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
2391
+ ifClauseOperand);
2392
+ cp.processAllocate (allocatorOperands, allocateOperands);
2393
+ cp.processDefault ();
2394
+ cp.processNumTeams (stmtCtx, numTeamsClauseOperand);
2395
+ cp.processThreadLimit (stmtCtx, threadLimitClauseOperand);
2396
+ if (cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols))
2397
+ TODO (currentLocation, " Reduction of TEAMS directive" );
2398
+
2399
+ return genOpWithBody<mlir::omp::TeamsOp>(
2400
+ converter, eval, currentLocation, outerCombined, &clauseList,
2401
+ /* num_teams_lower=*/ nullptr , numTeamsClauseOperand, ifClauseOperand,
2402
+ threadLimitClauseOperand, allocateOperands, allocatorOperands,
2403
+ reductionVars,
2404
+ reductionDeclSymbols.empty ()
2405
+ ? nullptr
2406
+ : mlir::ArrayAttr::get (converter.getFirOpBuilder ().getContext (),
2407
+ reductionDeclSymbols));
2408
+ }
2409
+
2362
2410
// ===----------------------------------------------------------------------===//
2363
2411
// genOMP() Code generation helper functions
2364
2412
// ===----------------------------------------------------------------------===//
@@ -2483,7 +2531,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2483
2531
if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
2484
2532
.test (ompDirective)) {
2485
2533
validDirective = true ;
2486
- TODO (currentLocation, " Teams construct" );
2534
+ genTeamsOp (converter, eval, currentLocation, loopOpClauseList,
2535
+ /* outerCombined=*/ true );
2487
2536
}
2488
2537
if (llvm::omp::allDistributeSet.test (ompDirective)) {
2489
2538
validDirective = true ;
@@ -2628,7 +2677,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2628
2677
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u ) &&
2629
2678
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u ) &&
2630
2679
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u ) &&
2631
- !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u )) {
2680
+ !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u ) &&
2681
+ !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u )) {
2632
2682
TODO (clauseLocation, " OpenMP Block construct clause" );
2633
2683
}
2634
2684
}
@@ -2667,7 +2717,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2667
2717
genTaskGroupOp (converter, eval, currentLocation, beginClauseList);
2668
2718
break ;
2669
2719
case llvm::omp::Directive::OMPD_teams:
2670
- TODO (currentLocation, " Teams construct" );
2720
+ genTeamsOp (converter, eval, currentLocation, beginClauseList,
2721
+ /* outerCombined=*/ false );
2671
2722
break ;
2672
2723
case llvm::omp::Directive::OMPD_workshare:
2673
2724
TODO (currentLocation, " Workshare construct" );
@@ -2683,7 +2734,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2683
2734
}
2684
2735
if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
2685
2736
.test (directive.v )) {
2686
- TODO (currentLocation, " Teams construct " );
2737
+ genTeamsOp (converter, eval, currentLocation, beginClauseList );
2687
2738
combinedDirective = true ;
2688
2739
}
2689
2740
if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
0 commit comments