Skip to content

Commit bd8af02

Browse files
committed
[mlir][scf] Add reductions support to scf.parallel fusion
1 parent a960703 commit bd8af02

File tree

2 files changed

+165
-9
lines changed

2 files changed

+165
-9
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,29 +131,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
131131
}
132132

133133
/// Prepends operations of firstPloop's body into secondPloop's body.
134-
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
135-
OpBuilder b,
134+
/// Updates secondPloop with new loop.
135+
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
136+
OpBuilder builder,
136137
llvm::function_ref<bool(Value, Value)> mayAlias) {
138+
Block *block1 = firstPloop.getBody();
139+
Block *block2 = secondPloop.getBody();
137140
IRMapping firstToSecondPloopIndices;
138-
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
139-
secondPloop.getBody()->getArguments());
141+
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
140142

141143
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
142144
mayAlias))
143145
return;
144146

145-
b.setInsertionPointToStart(secondPloop.getBody());
146-
for (auto &op : firstPloop.getBody()->without_terminator())
147-
b.clone(op, firstToSecondPloopIndices);
147+
DominanceInfo dom;
148+
for (Operation *user : firstPloop->getUsers())
149+
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
150+
return;
151+
152+
ValueRange inits1 = firstPloop.getInitVals();
153+
ValueRange inits2 = secondPloop.getInitVals();
154+
155+
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
156+
newInitVars.append(inits2.begin(), inits2.end());
157+
158+
IRRewriter b(builder);
159+
b.setInsertionPoint(secondPloop);
160+
auto newSecondPloop = b.create<ParallelOp>(
161+
secondPloop.getLoc(), secondPloop.getLowerBound(),
162+
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
163+
164+
Block *newBlock = newSecondPloop.getBody();
165+
newBlock->getTerminator()->erase();
166+
167+
block1->getTerminator()->erase();
168+
169+
b.inlineBlockBefore(block1, newBlock, newBlock->end(),
170+
newBlock->getArguments());
171+
b.inlineBlockBefore(block2, newBlock, newBlock->end(),
172+
newBlock->getArguments());
173+
174+
ValueRange results = newSecondPloop.getResults();
175+
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
176+
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
148177
firstPloop.erase();
178+
secondPloop.erase();
179+
secondPloop = newSecondPloop;
149180
}
150181

151182
void mlir::scf::naivelyFuseParallelOps(
152183
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
153184
OpBuilder b(region);
154185
// Consider every single block and attempt to fuse adjacent loops.
186+
SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
155187
for (auto &block : region) {
156-
SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
188+
ploopChains.clear();
189+
ploopChains.push_back({});
190+
157191
// Not using `walk()` to traverse only top-level parallel loops and also
158192
// make sure that there are no side-effecting ops between the parallel
159193
// loops.
@@ -171,7 +205,7 @@ void mlir::scf::naivelyFuseParallelOps(
171205
// TODO: Handle region side effects properly.
172206
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
173207
}
174-
for (ArrayRef<ParallelOp> ploops : ploopChains) {
208+
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
175209
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
176210
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
177211
}

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
387387
// CHECK-LABEL: func @do_not_fuse_alias
388388
// CHECK: scf.parallel
389389
// CHECK: scf.parallel
390+
391+
// -----
392+
393+
func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
394+
%c2 = arith.constant 2 : index
395+
%c0 = arith.constant 0 : index
396+
%c1 = arith.constant 1 : index
397+
%init1 = arith.constant 1.0 : f32
398+
%init2 = arith.constant 2.0 : f32
399+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
400+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
401+
scf.reduce(%A_elem) : f32 {
402+
^bb0(%lhs: f32, %rhs: f32):
403+
%1 = arith.addf %lhs, %rhs : f32
404+
scf.reduce.return %1 : f32
405+
}
406+
scf.yield
407+
}
408+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
409+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
410+
scf.reduce(%B_elem) : f32 {
411+
^bb0(%lhs: f32, %rhs: f32):
412+
%1 = arith.mulf %lhs, %rhs : f32
413+
scf.reduce.return %1 : f32
414+
}
415+
scf.yield
416+
}
417+
return %res1, %res2 : f32, f32
418+
}
419+
420+
// CHECK-LABEL: func @fuse_reductions
421+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
422+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
423+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
424+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
425+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
426+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
427+
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
428+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
429+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
430+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
431+
// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
432+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
433+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
434+
// CHECK: scf.reduce.return %[[R]] : f32
435+
// CHECK: }
436+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
437+
// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
438+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
439+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
440+
// CHECK: scf.reduce.return %[[R]] : f32
441+
// CHECK: }
442+
// CHECK: scf.yield
443+
// CHECK: return %[[RES]]#0, %[[RES]]#1
444+
445+
// -----
446+
447+
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
448+
%c2 = arith.constant 2 : index
449+
%c0 = arith.constant 0 : index
450+
%c1 = arith.constant 1 : index
451+
%init1 = arith.constant 1.0 : f32
452+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
453+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
454+
scf.reduce(%A_elem) : f32 {
455+
^bb0(%lhs: f32, %rhs: f32):
456+
%1 = arith.addf %lhs, %rhs : f32
457+
scf.reduce.return %1 : f32
458+
}
459+
scf.yield
460+
}
461+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
462+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
463+
scf.reduce(%B_elem) : f32 {
464+
^bb0(%lhs: f32, %rhs: f32):
465+
%1 = arith.mulf %lhs, %rhs : f32
466+
scf.reduce.return %1 : f32
467+
}
468+
scf.yield
469+
}
470+
return %res1, %res2 : f32, f32
471+
}
472+
473+
// %res1 is used as second scf.parallel arg, cannot fuse
474+
// CHECK-LABEL: func @reductions_use_res
475+
// CHECK: scf.parallel
476+
// CHECK: scf.parallel
477+
478+
// -----
479+
480+
func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
481+
%c2 = arith.constant 2 : index
482+
%c0 = arith.constant 0 : index
483+
%c1 = arith.constant 1 : index
484+
%init1 = arith.constant 1.0 : f32
485+
%init2 = arith.constant 2.0 : f32
486+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
487+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
488+
scf.reduce(%A_elem) : f32 {
489+
^bb0(%lhs: f32, %rhs: f32):
490+
%1 = arith.addf %lhs, %rhs : f32
491+
scf.reduce.return %1 : f32
492+
}
493+
scf.yield
494+
}
495+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
496+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
497+
%sum = arith.addf %B_elem, %res1 : f32
498+
scf.reduce(%sum) : f32 {
499+
^bb0(%lhs: f32, %rhs: f32):
500+
%1 = arith.mulf %lhs, %rhs : f32
501+
scf.reduce.return %1 : f32
502+
}
503+
scf.yield
504+
}
505+
return %res1, %res2 : f32, f32
506+
}
507+
508+
// %res1 is used inside second scf.parallel arg, cannot fuse
509+
// CHECK-LABEL: func @reductions_use_res_inside
510+
// CHECK: scf.parallel
511+
// CHECK: scf.parallel

0 commit comments

Comments
 (0)