Skip to content

Commit 7a9ef0f

Browse files
authored
Adding masked operation to OpenMP Dialect (#96022)
Adding MLIR Op support for omp masked. Omp masked is introduced in 5.2 standard and allows a region to be executed by threads specified by a programmer. This is achieved with the help of filter clause which helps to specify thread id expected to execute the region.
1 parent 3141c11 commit 7a9ef0f

File tree

6 files changed

+91
-2
lines changed

6 files changed

+91
-2
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ struct DoacrossClauseOps {
8181
IntegerAttr doacrossNumLoopsAttr;
8282
};
8383

84+
struct FilterClauseOps {
85+
Value filteredThreadIdVar;
86+
};
87+
8488
struct FinalClauseOps {
8589
Value finalVar;
8690
};
@@ -254,8 +258,7 @@ using DistributeClauseOps =
254258

255259
using LoopNestClauseOps = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
256260

257-
// TODO `filter` clause.
258-
using MaskedClauseOps = detail::Clauses<>;
261+
using MaskedClauseOps = detail::Clauses<FilterClauseOps>;
259262

260263
using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
261264

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,4 +1204,32 @@ class OpenMP_UseDevicePtrClauseSkip<
12041204

12051205
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
12061206

1207+
//===----------------------------------------------------------------------===//
1208+
// V5.2: [10.5.1] `filter` clause
1209+
//===----------------------------------------------------------------------===//
1210+
1211+
class OpenMP_FilterClauseSkip<
1212+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
1213+
bit description = false, bit extraClassDeclaration = false
1214+
> : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
1215+
description, extraClassDeclaration> {
1216+
let arguments = (ins
1217+
Optional<IntLikeType>:$filtered_thread_id
1218+
);
1219+
1220+
let assemblyFormat = [{
1221+
`filter` `(` $filtered_thread_id `:` type($filtered_thread_id) `)`
1222+
}];
1223+
1224+
let description = [{
1225+
If `filter` is specified, the masked construct masks the execution of
1226+
the region to only the thread id filtered. Other threads executing the
1227+
parallel region are not expected to execute the region specified within
1228+
the `masked` directive. If `filter` is not specified, master thread is
1229+
expected to execute the region enclosed within `masked` directive.
1230+
}];
1231+
}
1232+
1233+
def OpenMP_FilterClause : OpenMP_FilterClauseSkip<>;
1234+
12071235
#endif // OPENMP_CLAUSES

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,4 +1577,21 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
15771577
let hasRegionVerifier = 1;
15781578
}
15791579

1580+
//===----------------------------------------------------------------------===//
1581+
// [Spec 5.2] 10.5 masked Construct
1582+
//===----------------------------------------------------------------------===//
1583+
def MaskedOp : OpenMP_Op<"masked", clauses = [
1584+
OpenMP_FilterClause
1585+
], singleRegion = 1> {
1586+
let summary = "masked construct";
1587+
let description = [{
1588+
Masked construct allows to specify a structured block to be executed by a subset of
1589+
threads of the current team.
1590+
}] # clausesDescription;
1591+
1592+
let builders = [
1593+
OpBuilder<(ins CArg<"const MaskedClauseOps &">:$clauses)>
1594+
];
1595+
}
1596+
15801597
#endif // OPENMP_OPS

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,6 +2578,15 @@ LogicalResult PrivateClauseOp::verify() {
25782578
return success();
25792579
}
25802580

2581+
//===----------------------------------------------------------------------===//
2582+
// Spec 5.2: Masked construct (10.5)
2583+
//===----------------------------------------------------------------------===//
2584+
2585+
void MaskedOp::build(OpBuilder &builder, OperationState &state,
2586+
const MaskedClauseOps &clauses) {
2587+
MaskedOp::build(builder, state, clauses.filteredThreadIdVar);
2588+
}
2589+
25812590
#define GET_ATTRDEF_CLASSES
25822591
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
25832592

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2358,3 +2358,21 @@ func.func @byref_in_private(%arg0: index) {
23582358

23592359
return
23602360
}
2361+
2362+
// -----
2363+
func.func @masked_arg_type_mismatch(%arg0: f32) {
2364+
// expected-error @below {{'omp.masked' op operand #0 must be integer or index, but got 'f32'}}
2365+
"omp.masked"(%arg0) ({
2366+
omp.terminator
2367+
}) : (f32) -> ()
2368+
return
2369+
}
2370+
2371+
// -----
2372+
func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
2373+
// expected-error @below {{'omp.masked' op operand group starting at #0 requires 0 or 1 element, but found 2}}
2374+
"omp.masked"(%arg0, %arg1) ({
2375+
omp.terminator
2376+
}) : (i32, i32) -> ()
2377+
return
2378+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@ func.func @omp_master() -> () {
1616
return
1717
}
1818

19+
// CHECK-LABEL: omp_masked
20+
func.func @omp_masked(%filtered_thread_id : i32) -> () {
21+
// CHECK: omp.masked filter(%{{.*}} : i32)
22+
"omp.masked" (%filtered_thread_id) ({
23+
omp.terminator
24+
}) : (i32) -> ()
25+
26+
// CHECK: omp.masked
27+
"omp.masked" () ({
28+
omp.terminator
29+
}) : () -> ()
30+
return
31+
}
32+
1933
func.func @omp_taskwait() -> () {
2034
// CHECK: omp.taskwait
2135
omp.taskwait

0 commit comments

Comments
 (0)