Skip to content

Commit 544c806

Browse files
committed
[MLIR][OpenMP] Add private clause to omp.parallel
Extends the `omp.parallel` op by adding a `private` clause to model [first]private variables. This uses the `omp.private` op to map privatized variables to their corresponding privatizers.

Example `omp.parallel` op with a `private` clause:
```
omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) {
  // ... use %arg1 ...
  omp.terminator
}
```

Whether the variable is private or firstprivate is determined by the attributes of the corresponding `omp.private` op.
1 parent 8c6e96d commit 544c806

File tree

8 files changed

+258
-65
lines changed

8 files changed

+258
-65
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26402640
? nullptr
26412641
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
26422642
reductionDeclSymbols),
2643-
procBindKindAttr);
2643+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2644+
/*privatizers=*/nullptr);
26442645
}
26452646

26462647
static mlir::omp::SectionOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
270270
Variadic<AnyType>:$allocators_vars,
271271
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
272272
OptionalAttr<SymbolRefArrayAttr>:$reductions,
273-
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
273+
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
274+
Variadic<AnyType>:$private_vars,
275+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
274276

275277
let regions = (region AnyRegion:$region);
276278

@@ -291,7 +293,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
291293
$allocators_vars, type($allocators_vars)
292294
) `)`
293295
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
294-
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
296+
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions, $private_vars, type($private_vars), $privatizers) attr-dict
295297
}];
296298
let hasVerifier = 1;
297299
}

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -420,7 +420,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
420420
/* allocators_vars = */ llvm::SmallVector<Value>{},
421421
/* reduction_vars = */ llvm::SmallVector<Value>{},
422422
/* reductions = */ ArrayAttr{},
423-
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
423+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
424+
/* private_vars = */ ValueRange(),
425+
/* privatizers = */ nullptr);
424426
{
425427

426428
OpBuilder::InsertionGuard guard(rewriter);

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 114 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -429,68 +429,98 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
429429
// Parser, printer and verifier for ReductionVarList
430430
//===----------------------------------------------------------------------===//
431431

432-
ParseResult
433-
parseReductionClause(OpAsmParser &parser, Region &region,
434-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
436-
SmallVectorImpl<OpAsmParser::Argument> &privates) {
437-
if (failed(parser.parseOptionalKeyword("reduction")))
438-
return failure();
439-
432+
ParseResult parseClauseWithRegionArgs(
433+
OpAsmParser &parser, Region &region,
434+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435+
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
436+
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
440437
SmallVector<SymbolRefAttr> reductionVec;
438+
unsigned regionArgOffset = regionPrivateArgs.size();
441439

442440
if (failed(
443441
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
444442
if (parser.parseAttribute(reductionVec.emplace_back()) ||
445443
parser.parseOperand(operands.emplace_back()) ||
446444
parser.parseArrow() ||
447-
parser.parseArgument(privates.emplace_back()) ||
445+
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
448446
parser.parseColonType(types.emplace_back()))
449447
return failure();
450448
return success();
451449
})))
452450
return failure();
453451

454-
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
452+
auto *argsBegin = regionPrivateArgs.begin();
453+
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
454+
argsBegin + regionArgOffset + types.size());
455+
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
455456
prv.type = type;
456457
}
457458
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
458-
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
459+
symbols = ArrayAttr::get(parser.getContext(), reductions);
459460
return success();
460461
}
461462

462-
static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
463-
ValueRange operands, TypeRange types,
464-
ArrayAttr reductionSymbols) {
465-
p << "reduction(";
466-
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
467-
region.front().getArguments(), types),
468-
p, [&p](auto t) {
469-
auto [sym, op, arg, type] = t;
470-
p << sym << " " << op << " -> " << arg << " : "
471-
<< type;
472-
});
463+
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
464+
Region &region, StringRef clauseName,
465+
ValueRange operands, TypeRange types,
466+
ArrayAttr symbols,
467+
unsigned regionArgOffset) {
468+
p << clauseName << "(";
469+
470+
auto *argsBegin = region.front().getArguments().begin();
471+
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
472+
argsBegin + regionArgOffset + types.size());
473+
llvm::interleaveComma(
474+
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
475+
auto [sym, op, arg, type] = t;
476+
p << sym << " " << op << " -> " << arg << " : " << type;
477+
});
473478
p << ") ";
474479
}
475480

476-
static ParseResult
477-
parseParallelRegion(OpAsmParser &parser, Region &region,
478-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
481+
static ParseResult parseParallelRegion(
482+
OpAsmParser &parser, Region &region,
483+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
484+
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
485+
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
486+
llvm::SmallVectorImpl<Type> &privateVarsTypes,
487+
ArrayAttr &privatizerSymbols) {
488+
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
489+
490+
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
491+
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
492+
reductionVarTypes, reductionSymbols,
493+
regionPrivateArgs)))
494+
return failure();
495+
}
480496

481-
llvm::SmallVector<OpAsmParser::Argument> privates;
482-
if (succeeded(parseReductionClause(parser, region, operands, types,
483-
reductionSymbols, privates)))
484-
return parser.parseRegion(region, privates);
497+
if (succeeded(parser.parseOptionalKeyword("private"))) {
498+
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
499+
privateVarsTypes, privatizerSymbols,
500+
regionPrivateArgs)))
501+
return failure();
502+
}
485503

486-
return parser.parseRegion(region);
504+
return parser.parseRegion(region, regionPrivateArgs);
487505
}
488506

489507
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
490-
ValueRange operands, TypeRange types,
491-
ArrayAttr reductionSymbols) {
508+
ValueRange reductionVarOperands,
509+
TypeRange reductionVarTypes,
510+
ArrayAttr reductionSymbols,
511+
ValueRange privateVarOperands,
512+
TypeRange privateVarTypes,
513+
ArrayAttr privatizerSymbols) {
492514
if (reductionSymbols)
493-
printReductionClause(p, op, region, operands, types, reductionSymbols);
515+
printClauseWithRegionArgs(p, op, region, "reduction", reductionVarOperands,
516+
reductionVarTypes, reductionSymbols,
517+
/*regionArgOffset=*/0);
518+
519+
if (privatizerSymbols)
520+
printClauseWithRegionArgs(p, op, region, "private", privateVarOperands,
521+
privateVarTypes, privatizerSymbols,
522+
reductionVarOperands.size());
523+
494524
p.printRegion(region, /*printEntryBlockArgs=*/false);
495525
}
496526

@@ -1057,14 +1087,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
10571087
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
10581088
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
10591089
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1060-
/*proc_bind_val=*/nullptr);
1090+
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1091+
/*privatizers=*/nullptr);
10611092
state.addAttributes(attributes);
10621093
}
10631094

1095+
static LogicalResult verifyPrivateVarList(ParallelOp &op) {
1096+
auto privateVars = op.getPrivateVars();
1097+
auto privatizers = op.getPrivatizersAttr();
1098+
1099+
if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1100+
return success();
1101+
1102+
auto numPrivateVars = privateVars.size();
1103+
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1104+
1105+
if (numPrivateVars != numPrivatizers)
1106+
return op.emitError() << "inconsistent number of private variables and "
1107+
"privatizer op symbols, private vars: "
1108+
<< numPrivateVars
1109+
<< " vs. privatizer op symbols: " << numPrivatizers;
1110+
1111+
for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
1112+
Type varType = std::get<0>(privateVarInfo).getType();
1113+
SymbolRefAttr privatizerSym =
1114+
std::get<1>(privateVarInfo).cast<SymbolRefAttr>();
1115+
PrivateClauseOp privatizerOp =
1116+
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1117+
privatizerSym);
1118+
1119+
if (privatizerOp == nullptr)
1120+
return op.emitError() << "failed to lookup privatizer op with symbol: '"
1121+
<< privatizerSym << "'";
1122+
1123+
Type privatizerType = privatizerOp.getType();
1124+
1125+
if (varType != privatizerType)
1126+
return op.emitError()
1127+
<< "type mismatch between a "
1128+
<< (privatizerOp.getDataSharingType() ==
1129+
DataSharingClauseType::Private
1130+
? "private"
1131+
: "firstprivate")
1132+
<< " variable and its privatizer op, var type: " << varType
1133+
<< " vs. privatizer op type: " << privatizerType;
1134+
}
1135+
1136+
return success();
1137+
}
1138+
10641139
LogicalResult ParallelOp::verify() {
10651140
if (getAllocateVars().size() != getAllocatorsVars().size())
10661141
return emitError(
10671142
"expected equal sizes for allocate and allocator variables");
1143+
1144+
if (failed(verifyPrivateVarList(*this)))
1145+
return failure();
1146+
10681147
return verifyReductionVarList(*this, getReductions(), getReductionVars());
10691148
}
10701149

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1801,3 +1801,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
18011801
^bb0(%arg0: f32):
18021802
omp.yield(%arg0 : f32)
18031803
}
1804+
1805+
// -----
1806+
1807+
func.func @private_type_mismatch(%arg0: index) {
1808+
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1809+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1810+
omp.terminator
1811+
}
1812+
1813+
return
1814+
}
1815+
1816+
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
1817+
^bb0(%arg0: !llvm.ptr):
1818+
omp.yield(%arg0 : !llvm.ptr)
1819+
}
1820+
1821+
// -----
1822+
1823+
func.func @firstprivate_type_mismatch(%arg0: index) {
1824+
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1825+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1826+
omp.terminator
1827+
}
1828+
1829+
return
1830+
}
1831+
1832+
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
1833+
^bb0(%arg0: !llvm.ptr):
1834+
omp.yield(%arg0 : !llvm.ptr)
1835+
} copy {
1836+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1837+
omp.yield(%arg0 : !llvm.ptr)
1838+
}
1839+
1840+
// -----
1841+
1842+
func.func @undefined_privatizer(%arg0: index) {
1843+
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
1844+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1845+
omp.terminator
1846+
}
1847+
1848+
return
1849+
}
1850+
1851+
// -----
1852+
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
1853+
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
1854+
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
1855+
^bb0(%arg2: !llvm.ptr):
1856+
omp.terminator
1857+
}) : (!llvm.ptr) -> ()
1858+
return
1859+
}

mlir/test/Dialect/OpenMP/ops-2.mlir

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,74 @@
1+
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
2+
3+
// CHECK-LABEL: parallel_op_privatizers
4+
func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
5+
// CHECK: omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr)
6+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
7+
%0 = llvm.load %arg2 : !llvm.ptr -> i32
8+
%1 = llvm.load %arg3 : !llvm.ptr -> i32
9+
omp.terminator
10+
}
11+
return
12+
}
13+
14+
// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
15+
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
16+
// CHECK: ^bb0(%arg0: {{.*}}):
17+
^bb0(%arg0: !llvm.ptr):
18+
omp.yield(%arg0 : !llvm.ptr)
19+
}
20+
21+
// CHECK: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
22+
omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
23+
// CHECK: ^bb0(%arg0: {{.*}}):
24+
^bb0(%arg0: !llvm.ptr):
25+
omp.yield(%arg0 : !llvm.ptr)
26+
// CHECK: } copy {
27+
} copy {
28+
// CHECK: ^bb0(%arg0: {{.*}}, %arg1: {{.*}}):
29+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
30+
omp.yield(%arg0 : !llvm.ptr)
31+
}
32+
33+
// CHECK-LABEL: parallel_op_reduction_and_private
34+
func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
35+
// CHECK: omp.parallel
36+
// CHECK-SAME: reduction(
37+
// CHECK-SAME: @add_f32 %[[reduc_var:[0-9a-z]+]] -> %[[reduc_arg:[0-9a-z]+]] : !llvm.ptr,
38+
// CHECK-SAME: @add_f32 %[[reduc_var2:[0-9a-z]+]] -> %[[reduc_arg2:[0-9a-z]+]] : !llvm.ptr)
39+
//
40+
// CHECK-SAME: private(
41+
// CHECK-SAME: @x.privatizer %[[priv_var:[0-9a-z]+]] -> %[[priv_arg:[0-9a-z]+]] : !llvm.ptr,
42+
// CHECK-SAME: @y.privatizer %[[priv_var2:[0-9a-z]+]] -> %[[priv_arg2:[0-9a-z]+]] : !llvm.ptr)
43+
omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
44+
private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
45+
// CHECK: llvm.load %[[priv_arg]]
46+
%0 = llvm.load %priv_arg : !llvm.ptr -> f32
47+
// CHECK: llvm.load %[[priv_arg2]]
48+
%1 = llvm.load %priv_arg2 : !llvm.ptr -> f32
49+
// CHECK: llvm.load %[[reduc_arg]]
50+
%2 = llvm.load %reduc_arg : !llvm.ptr -> f32
51+
// CHECK: llvm.load %[[reduc_arg2]]
52+
%3 = llvm.load %reduc_arg2 : !llvm.ptr -> f32
53+
omp.terminator
54+
}
55+
return
56+
}
57+
58+
omp.reduction.declare @add_f32 : f32
59+
init {
60+
^bb0(%arg: f32):
61+
%0 = arith.constant 0.0 : f32
62+
omp.yield (%0 : f32)
63+
}
64+
combiner {
65+
^bb1(%arg0: f32, %arg1: f32):
66+
%1 = arith.addf %arg0, %arg1 : f32
67+
omp.yield (%1 : f32)
68+
}
69+
atomic {
70+
^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
71+
%2 = llvm.load %arg3 : !llvm.ptr -> f32
72+
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
73+
omp.yield
74+
}

0 commit comments

Comments
 (0)