Skip to content

Commit c067c6e

Browse files
committed
[mlir][openacc] Use new private representation in acc.parallel
Update acc.parallel private operands list to use the new design introduced in D150622. Test in flang/test/Lower/OpenACC/acc-parallel.f90 and flang/test/Lower/OpenACC/acc-parallel-loop.f90 are temporarly disabled and will be enabled with updated lowering in the follow-up patch. Reviewed By: razvanlupusoru Differential Revision: https://reviews.llvm.org/D150971
1 parent c606fef commit c067c6e

File tree

5 files changed

+137
-23
lines changed

5 files changed

+137
-23
lines changed

flang/test/Lower/OpenACC/acc-parallel-loop.f90

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -442,18 +442,19 @@ subroutine acc_parallel_loop
442442
! CHECK: acc.yield
443443
! CHECK-NEXT: }{{$}}
444444

445-
!$acc parallel loop private(a) firstprivate(b)
446-
DO i = 1, n
447-
a(i) = b(i)
448-
END DO
449-
450-
! CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
451-
! CHECK: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
452-
! CHECK: fir.do_loop
453-
! CHECK: acc.yield
454-
! CHECK-NEXT: }{{$}}
455-
! CHECK: acc.yield
456-
! CHECK-NEXT: }{{$}}
445+
! TODO: will be updated after lowering change in privatization to MLIR
446+
! !$acc parallel loop private(a) firstprivate(b)
447+
! DO i = 1, n
448+
! a(i) = b(i)
449+
! END DO
450+
451+
! TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
452+
! TODO: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
453+
! TODO: fir.do_loop
454+
! TODO: acc.yield
455+
! TODO-NEXT: }{{$}}
456+
! TODO: acc.yield
457+
! TODO-NEXT: }{{$}}
457458

458459
!$acc parallel loop seq
459460
DO i = 1, n

flang/test/Lower/OpenACC/acc-parallel.f90

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,12 @@ subroutine acc_parallel
288288
!CHECK: acc.detach accPtr(%[[ATTACH_D]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "d"}
289289
!CHECK: acc.detach accPtr(%[[ATTACH_E]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "e"}
290290

291-
!$acc parallel private(a) firstprivate(b) private(c)
292-
!$acc end parallel
291+
! TODO: will be updated after lowering change in privatization to MLIR
292+
! !$acc parallel private(a) firstprivate(b) private(c)
293+
! !$acc end parallel
293294

294-
!CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
295-
!CHECK: acc.yield
296-
!CHECK-NEXT: }{{$}}
295+
!TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
296+
!TODO: acc.yield
297+
!TODO-NEXT: }{{$}}
297298

298299
end subroutine acc_parallel

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
636636
UnitAttr:$selfAttr,
637637
OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
638638
Variadic<AnyType>:$reductionOperands,
639-
Variadic<AnyType>:$gangPrivateOperands,
639+
Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
640+
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
640641
Variadic<AnyType>:$gangFirstPrivateOperands,
641642
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
642643
OptionalAttr<DefaultValueAttr>:$defaultAttr);
@@ -659,7 +660,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
659660
type($gangFirstPrivateOperands) `)`
660661
| `num_gangs` `(` $numGangs `:` type($numGangs) `)`
661662
| `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
662-
| `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)`
663+
| `private` `(` custom<PrivatizationList>(
664+
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
665+
`)`
663666
| `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
664667
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
665668
| `self` `(` $selfCond `)`

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,43 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
436436
return success();
437437
}
438438

439+
//===----------------------------------------------------------------------===//
440+
// Custom parser and printer verifier for private clause
441+
//===----------------------------------------------------------------------===//
442+
443+
static ParseResult parsePrivatizationList(
444+
mlir::OpAsmParser &parser,
445+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
446+
llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
447+
llvm::SmallVector<SymbolRefAttr> privatizationVec;
448+
if (failed(parser.parseCommaSeparatedList([&]() {
449+
if (parser.parseAttribute(privatizationVec.emplace_back()) ||
450+
parser.parseArrow() ||
451+
parser.parseOperand(operands.emplace_back()) ||
452+
parser.parseColonType(types.emplace_back()))
453+
return failure();
454+
return success();
455+
})))
456+
return failure();
457+
llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
458+
privatizationVec.end());
459+
privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
460+
return success();
461+
}
462+
463+
static void
464+
printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
465+
mlir::OperandRange privateOperands,
466+
mlir::TypeRange privateTypes,
467+
std::optional<mlir::ArrayAttr> privatizations) {
468+
for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
469+
if (i != 0)
470+
p << ", ";
471+
p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
472+
<< privateOperands[i].getType();
473+
}
474+
}
475+
439476
//===----------------------------------------------------------------------===//
440477
// ParallelOp
441478
//===----------------------------------------------------------------------===//
@@ -455,6 +492,45 @@ static LogicalResult checkDataOperands(Op op,
455492
return success();
456493
}
457494

495+
static LogicalResult
496+
checkPrivatizationList(Operation *op,
497+
std::optional<mlir::ArrayAttr> privatizations,
498+
mlir::OperandRange privateOperands) {
499+
if (!privateOperands.empty()) {
500+
if (!privatizations || privatizations->size() != privateOperands.size())
501+
return op->emitOpError() << "expected as many privatizations symbol "
502+
"reference as private operands";
503+
} else {
504+
if (privatizations)
505+
return op->emitOpError() << "unexpected privatizations symbol reference";
506+
return success();
507+
}
508+
509+
llvm::DenseSet<Value> privates;
510+
for (auto args : llvm::zip(privateOperands, *privatizations)) {
511+
mlir::Value privateOperand = std::get<0>(args);
512+
513+
if (!privates.insert(privateOperand).second)
514+
return op->emitOpError() << "private operand appears more than once";
515+
516+
mlir::Type varType = privateOperand.getType();
517+
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
518+
auto decl =
519+
SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
520+
if (!decl)
521+
return op->emitOpError() << "expected symbol reference " << symbolRef
522+
<< " to point to a private declaration";
523+
524+
if (decl.getType() && decl.getType() != varType)
525+
return op->emitOpError()
526+
<< "expected private (" << varType
527+
<< ") to be the same type as private declaration ("
528+
<< decl.getType() << ")";
529+
}
530+
531+
return success();
532+
}
533+
458534
unsigned ParallelOp::getNumDataOperands() {
459535
return getReductionOperands().size() + getGangPrivateOperands().size() +
460536
getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
@@ -471,6 +547,9 @@ Value ParallelOp::getDataOperand(unsigned i) {
471547
}
472548

473549
LogicalResult acc::ParallelOp::verify() {
550+
if (failed(checkPrivatizationList(*this, getPrivatizations(),
551+
getGangPrivateOperands())))
552+
return failure();
474553
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
475554
}
476555

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,16 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
114114

115115
// -----
116116

117+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
118+
^bb0(%arg0: memref<10xf32>):
119+
%0 = memref.alloc() : memref<10xf32>
120+
acc.yield %0 : memref<10xf32>
121+
} destroy {
122+
^bb0(%arg0: memref<10xf32>):
123+
memref.dealloc %arg0 : memref<10xf32>
124+
acc.terminator
125+
}
126+
117127
func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
118128
%lb = arith.constant 0 : index
119129
%st = arith.constant 1 : index
@@ -126,7 +136,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
126136
%pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32>
127137
%pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
128138
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
129-
acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
139+
acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %c : memref<10xf32>) {
130140
acc.loop gang {
131141
scf.for %x = %lb to %c10 step %st {
132142
acc.loop worker {
@@ -168,7 +178,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
168178
// CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64
169179
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
170180
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
171-
// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private([[ARG2]] : memref<10xf32>) {
181+
// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> [[ARG2]] : memref<10xf32>) {
172182
// CHECK-NEXT: acc.loop gang {
173183
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
174184
// CHECK-NEXT: acc.loop worker {
@@ -358,6 +368,26 @@ func.func @acc_loop_multiple_block() {
358368

359369
// -----
360370

371+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
372+
^bb0(%arg0: memref<10xf32>):
373+
%0 = memref.alloc() : memref<10xf32>
374+
acc.yield %0 : memref<10xf32>
375+
} destroy {
376+
^bb0(%arg0: memref<10xf32>):
377+
memref.dealloc %arg0 : memref<10xf32>
378+
acc.terminator
379+
}
380+
381+
acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
382+
^bb0(%arg0: memref<10x10xf32>):
383+
%0 = memref.alloc() : memref<10x10xf32>
384+
acc.yield %0 : memref<10x10xf32>
385+
} destroy {
386+
^bb0(%arg0: memref<10x10xf32>):
387+
memref.dealloc %arg0 : memref<10x10xf32>
388+
acc.terminator
389+
}
390+
361391
func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
362392
%i64value = arith.constant 1 : i64
363393
%i32value = arith.constant 1 : i32
@@ -394,7 +424,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
394424
}
395425
acc.parallel vector_length(%idxValue: index) {
396426
}
397-
acc.parallel private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
427+
acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
398428
}
399429
acc.parallel {
400430
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -445,7 +475,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
445475
// CHECK-NEXT: }
446476
// CHECK: acc.parallel vector_length([[IDXVALUE]] : index) {
447477
// CHECK-NEXT: }
448-
// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
478+
// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
449479
// CHECK-NEXT: }
450480
// CHECK: acc.parallel {
451481
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}

0 commit comments

Comments
 (0)