Skip to content

Commit 5e2e91d

Browse files
committed
[MLIR][OpenMP] Add private clause to omp.parallel
Extends the `omp.parallel` op by adding a `private` clause to model [first]private variables. This uses the `omp.private` op to map privatized variables to their corresponding privatizers. Example `omp.private` op with `private` variable: ``` omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) { ^bb0(%arg1: !llvm.ptr): // ... use %arg1 ... omp.terminator } ``` Whether the variable is private or firstprivate is determined by the attributes of the corresponding `omp.private` op.
1 parent 3b18d2c commit 5e2e91d

File tree

7 files changed

+198
-10
lines changed

7 files changed

+198
-10
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2590,7 +2590,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25902590
? nullptr
25912591
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
25922592
reductionDeclSymbols),
2593-
procBindKindAttr);
2593+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2594+
/*privatizers=*/nullptr);
25942595
}
25952596

25962597
static mlir::omp::SectionOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
276276
Variadic<AnyType>:$allocators_vars,
277277
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
278278
OptionalAttr<SymbolRefArrayAttr>:$reductions,
279-
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
279+
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
280+
Variadic<AnyType>:$private_vars,
281+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
280282

281283
let regions = (region AnyRegion:$region);
282284

@@ -300,6 +302,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
300302
$allocators_vars, type($allocators_vars)
301303
) `)`
302304
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
305+
| `private` `(`
306+
custom<PrivateVarList>(
307+
$private_vars, type($private_vars), $privatizers
308+
) `)`
303309
) $region attr-dict
304310
}];
305311
let hasVerifier = 1;

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
420420
/* allocators_vars = */ llvm::SmallVector<Value>{},
421421
/* reduction_vars = */ llvm::SmallVector<Value>{},
422422
/* reductions = */ ArrayAttr{},
423-
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
423+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
424+
/* private_vars = */ ValueRange(),
425+
/* privatizers = */ nullptr);
424426
{
425427

426428
OpBuilder::InsertionGuard guard(rewriter);

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,14 +990,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
990990
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
991991
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
992992
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
993-
/*proc_bind_val=*/nullptr);
993+
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
994+
/*privatizers=*/nullptr);
994995
state.addAttributes(attributes);
995996
}
996997

998+
static LogicalResult verifyPrivateVarList(ParallelOp &op) {
999+
auto privateVars = op.getPrivateVars();
1000+
auto privatizers = op.getPrivatizersAttr();
1001+
1002+
if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1003+
return success();
1004+
1005+
auto numPrivateVars = privateVars.size();
1006+
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1007+
1008+
if (numPrivateVars != numPrivatizers)
1009+
return op.emitError() << "inconsistent number of private variables and "
1010+
"privatizer op symbols, private vars: "
1011+
<< numPrivateVars
1012+
<< " vs. privatizer op symbols: " << numPrivatizers;
1013+
1014+
for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
1015+
Type varType = std::get<0>(privateVarInfo).getType();
1016+
SymbolRefAttr privatizerSym =
1017+
std::get<1>(privateVarInfo).cast<SymbolRefAttr>();
1018+
PrivateClauseOp privatizerOp =
1019+
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1020+
privatizerSym);
1021+
1022+
if (privatizerOp == nullptr)
1023+
return op.emitError() << "failed to lookup privatizer op with symbol: '"
1024+
<< privatizerSym << "'";
1025+
1026+
Type privatizerType = privatizerOp.getType();
1027+
1028+
if (varType != privatizerType)
1029+
return op.emitError()
1030+
<< "type mismatch between a "
1031+
<< (privatizerOp.getDataSharingType() ==
1032+
DataSharingClauseType::Private
1033+
? "private"
1034+
: "firstprivate")
1035+
<< " variable and its privatizer op, var type: " << varType
1036+
<< " vs. privatizer op type: " << privatizerType;
1037+
}
1038+
1039+
return success();
1040+
}
1041+
9971042
LogicalResult ParallelOp::verify() {
9981043
if (getAllocateVars().size() != getAllocatorsVars().size())
9991044
return emitError(
10001045
"expected equal sizes for allocate and allocator variables");
1046+
1047+
if (failed(verifyPrivateVarList(*this)))
1048+
return failure();
1049+
10011050
return verifyReductionVarList(*this, getReductions(), getReductionVars());
10021051
}
10031052

@@ -1670,6 +1719,68 @@ LogicalResult PrivateClauseOp::verify() {
16701719
return success();
16711720
}
16721721

1722+
static ParseResult parsePrivateVarList(
1723+
OpAsmParser &parser,
1724+
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
1725+
llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privatizersAttr) {
1726+
SymbolRefAttr privatizerSym;
1727+
OpAsmParser::UnresolvedOperand privateArg;
1728+
OpAsmParser::UnresolvedOperand regionArg;
1729+
Type argType;
1730+
1731+
SmallVector<SymbolRefAttr> privatizersVec;
1732+
1733+
auto parsePrivatizers = [&]() -> ParseResult {
1734+
// @privatizer %var -> %region_arg : type
1735+
if (parser.parseAttribute(privatizerSym) ||
1736+
parser.parseOperand(privateArg) || parser.parseArrow() ||
1737+
parser.parseOperand(regionArg) || parser.parseColon() ||
1738+
parser.parseType(argType)) {
1739+
return failure();
1740+
}
1741+
1742+
privatizersVec.push_back(privatizerSym);
1743+
privateVarsOperands.push_back(privateArg);
1744+
privateVarsTypes.push_back(argType);
1745+
1746+
return success();
1747+
};
1748+
1749+
if (parser.parseCommaSeparatedList(parsePrivatizers))
1750+
return failure();
1751+
1752+
SmallVector<Attribute> privatizers(privatizersVec.begin(),
1753+
privatizersVec.end());
1754+
privatizersAttr = ArrayAttr::get(parser.getContext(), privatizers);
1755+
1756+
return success();
1757+
}
1758+
1759+
static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
1760+
OperandRange privateVars,
1761+
TypeRange privateVarTypes,
1762+
std::optional<ArrayAttr> privatizersAttr) {
1763+
unsigned argIndex = 0;
1764+
assert(privatizersAttr);
1765+
1766+
for (const auto &priateVarArgInfo :
1767+
llvm::zip(*privatizersAttr, privateVars, op->getRegion(0).getArguments(),
1768+
privateVarTypes)) {
1769+
assert(privatizersAttr);
1770+
printer << std::get<0>(priateVarArgInfo) << " "
1771+
<< std::get<1>(priateVarArgInfo) << " -> ";
1772+
1773+
std::get<2>(priateVarArgInfo)
1774+
.printAsOperand(printer.getStream(), OpPrintingFlags());
1775+
1776+
printer << " : " << std::get<3>(priateVarArgInfo);
1777+
1778+
++argIndex;
1779+
if (argIndex < privateVars.size())
1780+
printer << ", ";
1781+
}
1782+
}
1783+
16731784
#define GET_ATTRDEF_CLASSES
16741785
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
16751786

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,3 +1808,62 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
18081808
^bb0(%arg0: f32):
18091809
omp.yield(%arg0 : f32)
18101810
}
1811+
1812+
// -----
1813+
1814+
func.func @private_type_mismatch(%arg0: index) {
1815+
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1816+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1817+
^bb0(%arg2: !llvm.ptr):
1818+
omp.terminator
1819+
}
1820+
1821+
return
1822+
}
1823+
1824+
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
1825+
^bb0(%arg0: !llvm.ptr):
1826+
omp.yield(%arg0 : !llvm.ptr)
1827+
}
1828+
1829+
// -----
1830+
1831+
func.func @firstprivate_type_mismatch(%arg0: index) {
1832+
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1833+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1834+
^bb0(%arg2: !llvm.ptr):
1835+
omp.terminator
1836+
}
1837+
1838+
return
1839+
}
1840+
1841+
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
1842+
^bb0(%arg0: !llvm.ptr):
1843+
omp.yield(%arg0 : !llvm.ptr)
1844+
} copy {
1845+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1846+
omp.yield(%arg0 : !llvm.ptr)
1847+
}
1848+
1849+
// -----
1850+
1851+
func.func @undefined_privatizer(%arg0: index) {
1852+
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
1853+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1854+
^bb0(%arg2: !llvm.ptr):
1855+
omp.terminator
1856+
}
1857+
1858+
return
1859+
}
1860+
1861+
// -----
1862+
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
1863+
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
1864+
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
1865+
^bb0(%arg2: !llvm.ptr):
1866+
omp.terminator
1867+
}) : (!llvm.ptr) -> ()
1868+
return
1869+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
5959
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6060
"omp.parallel"(%num_threads, %data_var, %data_var) ({
6161
omp.terminator
62-
}) {operandSegmentSizes = array<i32: 0,1,1,1,0>} : (i32, memref<i32>, memref<i32>) -> ()
62+
}) {operandSegmentSizes = array<i32: 0,1,1,1,0,0>} : (i32, memref<i32>, memref<i32>) -> ()
6363

6464
// CHECK: omp.barrier
6565
omp.barrier
@@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
6868
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6969
"omp.parallel"(%if_cond, %data_var, %data_var) ({
7070
omp.terminator
71-
}) {operandSegmentSizes = array<i32: 1,0,1,1,0>} : (i1, memref<i32>, memref<i32>) -> ()
71+
}) {operandSegmentSizes = array<i32: 1,0,1,1,0,0>} : (i1, memref<i32>, memref<i32>) -> ()
7272

7373
// test without allocate
7474
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
7575
"omp.parallel"(%if_cond, %num_threads) ({
7676
omp.terminator
77-
}) {operandSegmentSizes = array<i32: 1,1,0,0,0>} : (i1, i32) -> ()
77+
}) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (i1, i32) -> ()
7878

7979
omp.terminator
80-
}) {operandSegmentSizes = array<i32: 1,1,1,1,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
80+
}) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
8181

8282
// test with multiple parameters for single variadic argument
8383
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
8484
"omp.parallel" (%data_var, %data_var) ({
8585
omp.terminator
86-
}) {operandSegmentSizes = array<i32: 0,0,1,1,0>} : (memref<i32>, memref<i32>) -> ()
86+
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
8787

8888
return
8989
}

mlir/test/Dialect/OpenMP/roundtrip.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
22

3+
// CHECK-LABEL: parallel_op_privatizers
4+
func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
5+
// CHECK: omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr)
6+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
7+
^bb0(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
8+
omp.terminator
9+
}
10+
return
11+
}
12+
313
// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
414
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
515
// CHECK: ^bb0(%arg0: {{.*}}):
@@ -18,4 +28,3 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
1828
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1929
omp.yield(%arg0 : !llvm.ptr)
2030
}
21-

0 commit comments

Comments
 (0)