Commit 8843d54

[MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) (#89212)
This patch changes the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass so that it also introduces a nested `omp.loop_nest` and follows the new loop-wrapper role of `omp.wsloop`. This PR will not pass premerge tests on its own; all patches in the stack are needed before it can compile and pass tests.
1 parent 1465299 commit 8843d54
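
For orientation, here is a rough sketch of the nesting this lowering now produces, based on the updated FileCheck expectations in the tests below. The value names (%i, %j, %lb0, ...) are illustrative and clauses such as num_threads are omitted; this is not verbatim pass output.

Input (SCF):

  scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
    "test.payload"(%i, %j) : (index, index) -> ()
  }

Output shape (OpenMP dialect, schematic):

  omp.parallel {
    omp.wsloop {    // loop wrapper: carries no bounds of its own
      omp.loop_nest (%i, %j) : index = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
        memref.alloca_scope {
          "test.payload"(%i, %j) : (index, index) -> ()
        }
        omp.yield
      }
      omp.terminator    // terminates the omp.wsloop region
    }
    omp.terminator      // terminates the omp.parallel region
  }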

File tree

3 files changed (+67, -20 lines):

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/test/Conversion/SCFToOpenMP/reductions.mlir
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 39 additions & 12 deletions
@@ -461,18 +461,50 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
       // Replace the loop.
       {
         OpBuilder::InsertionGuard allocaGuard(rewriter);
-        auto loop = rewriter.create<omp::WsloopOp>(
+        // Create worksharing loop wrapper.
+        auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+        if (!reductionVariables.empty()) {
+          wsloopOp.setReductionsAttr(
+              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+          wsloopOp.getReductionVarsMutable().append(reductionVariables);
+        }
+        rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+        // The wrapper's entry block arguments will define the reduction
+        // variables.
+        llvm::SmallVector<mlir::Type> reductionTypes;
+        reductionTypes.reserve(reductionVariables.size());
+        llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+                        [](mlir::Value v) { return v.getType(); });
+        rewriter.createBlock(
+            &wsloopOp.getRegion(), {}, reductionTypes,
+            llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+                                              parallelOp.getLoc()));
+
+        rewriter.setInsertionPoint(
+            rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+        // Create loop nest and populate region with contents of scf.parallel.
+        auto loopOp = rewriter.create<omp::LoopNestOp>(
             parallelOp.getLoc(), parallelOp.getLowerBound(),
             parallelOp.getUpperBound(), parallelOp.getStep());
-        rewriter.create<omp::TerminatorOp>(loc);
 
-        rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
-                                    loop.getRegion().begin());
+        rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+                                    loopOp.getRegion().begin());
 
-        Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
-                                         loop.getRegion().begin()->begin());
+        // Remove reduction-related block arguments from omp.loop_nest and
+        // redirect uses to the corresponding omp.wsloop block argument.
+        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+        unsigned numLoops = parallelOp.getNumLoops();
+        rewriter.replaceAllUsesWith(
+            loopOpEntryBlock.getArguments().drop_front(numLoops),
+            wsloopOp.getRegion().getArguments());
+        loopOpEntryBlock.eraseArguments(
+            numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
 
-        rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
+        Block *ops =
+            rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
+        rewriter.setInsertionPointToStart(&loopOpEntryBlock);
 
         auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
                                                             TypeRange());
@@ -481,11 +513,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         rewriter.mergeBlocks(ops, scopeBlock);
         rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
         rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
-        if (!reductionVariables.empty()) {
-          loop.setReductionsAttr(
-              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
-          loop.getReductionVarsMutable().append(reductionVariables);
-        }
       }
     }
 
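To illustrate the reduction handling above: the reduction clause and the block argument carrying the private reduction value now sit on the omp.wsloop wrapper, while omp.loop_nest keeps only the induction variables (its former reduction arguments are erased and their uses redirected to the wrapper's argument). The following is a hand-written sketch of the expected shape, consistent with the reductions.mlir checks below; the declaration symbol @add_f32 and all value names are illustrative, not actual pass output.

  omp.parallel {
    // %red_buf is the reduction buffer; %pvt_buf is the omp.wsloop entry
    // block argument bound to it by the reduction clause.
    omp.wsloop reduction(@add_f32 %red_buf -> %pvt_buf : !llvm.ptr) {
      // Only induction variables remain as omp.loop_nest block arguments.
      omp.loop_nest (%i, %j) : index = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
        memref.alloca_scope {
          // the loop body reads and updates the private value through %pvt_buf
        }
        omp.yield
      }
      omp.terminator
    }
    omp.terminator
  }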
mlir/test/Conversion/SCFToOpenMP/reductions.mlir

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.parallel
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%zero) -> (f32) {
@@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK: omp.terminator
   // CHECK: omp.terminator
   // CHECK: llvm.load %[[BUF]]
   return
@@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
   %one = arith.constant 1 : i32
   // CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
   // CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
+  // CHECK: omp.loop_nest
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%one) -> (i32) {
     // CHECK: %[[C2:.*]] = arith.constant 2 : i32
@@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
   // CHECK-SAME: @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                                      step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
@@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK: omp.terminator
   // CHECK: omp.terminator
   // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
   // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64

mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

Lines changed: 23 additions & 8 deletions
@@ -2,17 +2,20 @@
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
-               %arg3: index, %arg4: index, %arg5: index) {
+                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
     // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
     "test.payload"(%i, %j) : (index, index) -> ()
     // CHECK: omp.yield
     // CHECK: }
   }
+  // CHECK: omp.terminator
+  // CHECK: }
   // CHECK: omp.terminator
   // CHECK: }
   return
@@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                         %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
-  // CHECK: memref.alloca_scope
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: omp.parallel
-    // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+    // CHECK: omp.wsloop {
+    // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: memref.alloca_scope
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
       "test.payload"(%i, %j) : (index, index) -> ()
      // CHECK: }
     }
-    // CHECK: omp.yield
+    // CHECK: omp.yield
+    // CHECK: }
+    // CHECK: omp.terminator
   // CHECK: }
   }
+  // CHECK: omp.terminator
+  // CHECK: }
   // CHECK: omp.terminator
   // CHECK: }
   return
@@ -47,27 +56,33 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                           %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
     "test.payload1"(%i) : (index) -> ()
     // CHECK: omp.yield
     // CHECK: }
   }
+  // CHECK: omp.terminator
+  // CHECK: }
   // CHECK: omp.terminator
   // CHECK: }
 
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
     "test.payload2"(%j) : (index) -> ()
     // CHECK: omp.yield
     // CHECK: }
   }
+  // CHECK: omp.terminator
+  // CHECK: }
   // CHECK: omp.terminator
   // CHECK: }
   return
