@@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472
472
p << stringifyClauseOrderKind (order.getValue ());
473
473
}
474
474
475
+ template <typename ClauseTypeAttr, typename ClauseType>
476
+ static ParseResult
477
+ parseGranularityClause (OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478
+ std::optional<OpAsmParser::UnresolvedOperand> &operand,
479
+ Type &operandType,
480
+ std::optional<ClauseType> (*symbolizeClause)(StringRef),
481
+ StringRef clauseName) {
482
+ StringRef enumStr;
483
+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
484
+ if (std::optional<ClauseType> enumValue = symbolizeClause (enumStr)) {
485
+ prescriptiveness = ClauseTypeAttr::get (parser.getContext (), *enumValue);
486
+ if (parser.parseComma ())
487
+ return failure ();
488
+ } else {
489
+ return parser.emitError (parser.getCurrentLocation ())
490
+ << " invalid " << clauseName << " modifier : '" << enumStr << " '" ;
491
+ ;
492
+ }
493
+ }
494
+
495
+ OpAsmParser::UnresolvedOperand var;
496
+ if (succeeded (parser.parseOperand (var))) {
497
+ operand = var;
498
+ } else {
499
+ return parser.emitError (parser.getCurrentLocation ())
500
+ << " expected " << clauseName << " operand" ;
501
+ }
502
+
503
+ if (operand.has_value ()) {
504
+ if (parser.parseColonType (operandType))
505
+ return failure ();
506
+ }
507
+
508
+ return success ();
509
+ }
510
+
511
+ template <typename ClauseTypeAttr, typename ClauseType>
512
+ static void
513
+ printGranularityClause (OpAsmPrinter &p, Operation *op,
514
+ ClauseTypeAttr prescriptiveness, Value operand,
515
+ mlir::Type operandType,
516
+ StringRef (*stringifyClauseType)(ClauseType)) {
517
+
518
+ if (prescriptiveness)
519
+ p << stringifyClauseType (prescriptiveness.getValue ()) << " , " ;
520
+
521
+ if (operand)
522
+ p << operand << " : " << operandType;
523
+ }
524
+
525
+ // ===----------------------------------------------------------------------===//
526
+ // Parser and printer for grainsize Clause
527
+ // ===----------------------------------------------------------------------===//
528
+
529
+ // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530
+ static ParseResult
531
+ parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532
+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533
+ Type &grainsizeType) {
534
+ return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535
+ parser, grainsizeMod, grainsize, grainsizeType,
536
+ &symbolizeClauseGrainsizeType, " grainsize" );
537
+ }
538
+
539
+ static void printGrainsizeClause (OpAsmPrinter &p, Operation *op,
540
+ ClauseGrainsizeTypeAttr grainsizeMod,
541
+ Value grainsize, mlir::Type grainsizeType) {
542
+ printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543
+ p, op, grainsizeMod, grainsize, grainsizeType,
544
+ &stringifyClauseGrainsizeType);
545
+ }
546
+
547
+ // ===----------------------------------------------------------------------===//
548
+ // Parser and printer for num_tasks Clause
549
+ // ===----------------------------------------------------------------------===//
550
+
551
+ // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552
+ static ParseResult
553
+ parseNumTasksClause (OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554
+ std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555
+ Type &numTasksType) {
556
+ return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557
+ parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558
+ " num_tasks" );
559
+ }
560
+
561
+ static void printNumTasksClause (OpAsmPrinter &p, Operation *op,
562
+ ClauseNumTasksTypeAttr numTasksMod,
563
+ Value numTasks, mlir::Type numTasksType) {
564
+ printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565
+ p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566
+ }
567
+
475
568
// ===----------------------------------------------------------------------===//
476
569
// Parsers for operations including clauses that define entry block arguments.
477
570
// ===----------------------------------------------------------------------===//
@@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2593
2686
const TaskloopOperands &clauses) {
2594
2687
MLIRContext *ctx = builder.getContext ();
2595
2688
// TODO Store clauses in op: privateVars, privateSyms.
2596
- TaskloopOp::build (
2597
- builder, state, clauses.allocateVars , clauses.allocatorVars ,
2598
- clauses.final , clauses.grainsize , clauses.ifExpr , clauses.inReductionVars ,
2599
- makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2600
- makeArrayAttr (ctx, clauses.inReductionSyms ), clauses.mergeable ,
2601
- clauses.nogroup , clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2602
- /* private_syms=*/ nullptr , clauses.reductionMod , clauses.reductionVars ,
2603
- makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2604
- makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2689
+ TaskloopOp::build (builder, state, clauses.allocateVars , clauses.allocatorVars ,
2690
+ clauses.final , clauses.grainsizeMod , clauses.grainsize ,
2691
+ clauses.ifExpr , clauses.inReductionVars ,
2692
+ makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2693
+ makeArrayAttr (ctx, clauses.inReductionSyms ),
2694
+ clauses.mergeable , clauses.nogroup , clauses.numTasksMod ,
2695
+ clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2696
+ /* private_syms=*/ nullptr , clauses.reductionMod ,
2697
+ clauses.reductionVars ,
2698
+ makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2699
+ makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2605
2700
}
2606
2701
2607
2702
SmallVector<Value> TaskloopOp::getAllReductionVars () {
0 commit comments