Skip to content

Commit d39f4a1

Browse files
authored
[MLIR][OpenMP]Add prescriptiveness-modifier support to granularity clauses of taskloop construct (#128477)
Added modifier(strict) support to the granularity(grainsize and num_tasks) clauses of taskloop construct.
1 parent 4af2e36 commit d39f4a1

File tree

4 files changed

+150
-17
lines changed

4 files changed

+150
-17
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
436436
bit description = false, bit extraClassDeclaration = false
437437
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
438438
extraClassDeclaration> {
439-
let arguments = (ins
440-
Optional<IntLikeType>:$grainsize
441-
);
439+
let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
440+
Optional<IntLikeType>:$grainsize);
442441

443442
let optAssemblyFormat = [{
444-
`grainsize` `(` $grainsize `:` type($grainsize) `)`
443+
`grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
445444
}];
446445

447446
let description = [{
@@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
895894
bit description = false, bit extraClassDeclaration = false
896895
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
897896
extraClassDeclaration> {
898-
let arguments = (ins
899-
Optional<IntLikeType>:$num_tasks
900-
);
897+
let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
898+
Optional<IntLikeType>:$num_tasks);
901899

902900
let optAssemblyFormat = [{
903-
`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
901+
`num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
904902
}];
905903

906904
let description = [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472
p << stringifyClauseOrderKind(order.getValue());
473473
}
474474

475+
template <typename ClauseTypeAttr, typename ClauseType>
476+
static ParseResult
477+
parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478+
std::optional<OpAsmParser::UnresolvedOperand> &operand,
479+
Type &operandType,
480+
std::optional<ClauseType> (*symbolizeClause)(StringRef),
481+
StringRef clauseName) {
482+
StringRef enumStr;
483+
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
484+
if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485+
prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
486+
if (parser.parseComma())
487+
return failure();
488+
} else {
489+
return parser.emitError(parser.getCurrentLocation())
490+
<< "invalid " << clauseName << " modifier : '" << enumStr << "'";
491+
;
492+
}
493+
}
494+
495+
OpAsmParser::UnresolvedOperand var;
496+
if (succeeded(parser.parseOperand(var))) {
497+
operand = var;
498+
} else {
499+
return parser.emitError(parser.getCurrentLocation())
500+
<< "expected " << clauseName << " operand";
501+
}
502+
503+
if (operand.has_value()) {
504+
if (parser.parseColonType(operandType))
505+
return failure();
506+
}
507+
508+
return success();
509+
}
510+
511+
template <typename ClauseTypeAttr, typename ClauseType>
512+
static void
513+
printGranularityClause(OpAsmPrinter &p, Operation *op,
514+
ClauseTypeAttr prescriptiveness, Value operand,
515+
mlir::Type operandType,
516+
StringRef (*stringifyClauseType)(ClauseType)) {
517+
518+
if (prescriptiveness)
519+
p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
520+
521+
if (operand)
522+
p << operand << ": " << operandType;
523+
}
524+
525+
//===----------------------------------------------------------------------===//
526+
// Parser and printer for grainsize Clause
527+
//===----------------------------------------------------------------------===//
528+
529+
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530+
static ParseResult
531+
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532+
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533+
Type &grainsizeType) {
534+
return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535+
parser, grainsizeMod, grainsize, grainsizeType,
536+
&symbolizeClauseGrainsizeType, "grainsize");
537+
}
538+
539+
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
540+
ClauseGrainsizeTypeAttr grainsizeMod,
541+
Value grainsize, mlir::Type grainsizeType) {
542+
printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543+
p, op, grainsizeMod, grainsize, grainsizeType,
544+
&stringifyClauseGrainsizeType);
545+
}
546+
547+
//===----------------------------------------------------------------------===//
548+
// Parser and printer for num_tasks Clause
549+
//===----------------------------------------------------------------------===//
550+
551+
// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552+
static ParseResult
553+
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554+
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555+
Type &numTasksType) {
556+
return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557+
parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558+
"num_tasks");
559+
}
560+
561+
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
562+
ClauseNumTasksTypeAttr numTasksMod,
563+
Value numTasks, mlir::Type numTasksType) {
564+
printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565+
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566+
}
567+
475568
//===----------------------------------------------------------------------===//
476569
// Parsers for operations including clauses that define entry block arguments.
477570
//===----------------------------------------------------------------------===//
@@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
25932686
const TaskloopOperands &clauses) {
25942687
MLIRContext *ctx = builder.getContext();
25952688
// TODO Store clauses in op: privateVars, privateSyms.
2596-
TaskloopOp::build(
2597-
builder, state, clauses.allocateVars, clauses.allocatorVars,
2598-
clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2599-
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2600-
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2601-
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2602-
/*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
2603-
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2604-
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2689+
TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2690+
clauses.final, clauses.grainsizeMod, clauses.grainsize,
2691+
clauses.ifExpr, clauses.inReductionVars,
2692+
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2693+
makeArrayAttr(ctx, clauses.inReductionSyms),
2694+
clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
2695+
clauses.numTasks, clauses.priority, /*private_vars=*/{},
2696+
/*private_syms=*/nullptr, clauses.reductionMod,
2697+
clauses.reductionVars,
2698+
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2699+
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
26052700
}
26062701

26072702
SmallVector<Value> TaskloopOp::getAllReductionVars() {

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20642064

20652065
// -----
20662066

2067+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
2068+
%testi64 = "test.i64"() : () -> (i64)
2069+
// expected-error @below {{invalid grainsize modifier : 'strict1'}}
2070+
omp.taskloop grainsize(strict1, %testi64: i64) {
2071+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2072+
omp.yield
2073+
}
2074+
}
2075+
return
2076+
}
2077+
// -----
2078+
2079+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
2080+
%testi64 = "test.i64"() : () -> (i64)
2081+
// expected-error @below {{invalid num_tasks modifier : 'default'}}
2082+
omp.taskloop num_tasks(default, %testi64: i64) {
2083+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2084+
omp.yield
2085+
}
2086+
}
2087+
return
2088+
}
2089+
// -----
2090+
20672091
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20682092
// expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
20692093
omp.taskloop {

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
24172417
}
24182418
}
24192419

2420+
// CHECK: omp.taskloop grainsize(strict, %{{[^:]+}}: i64) {
2421+
omp.taskloop grainsize(strict, %testi64: i64) {
2422+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2423+
// CHECK: omp.yield
2424+
omp.yield
2425+
}
2426+
}
2427+
2428+
// CHECK: omp.taskloop num_tasks(strict, %{{[^:]+}}: i64) {
2429+
omp.taskloop num_tasks(strict, %testi64: i64) {
2430+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2431+
// CHECK: omp.yield
2432+
omp.yield
2433+
}
2434+
}
2435+
24202436
// CHECK: omp.taskloop nogroup {
24212437
omp.taskloop nogroup {
24222438
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {

0 commit comments

Comments
 (0)