Skip to content

Commit 10056c8

Browse files
[mlir][SCF] scf.parallel: Make reductions part of the terminator (#75314)
This commit makes reductions part of the terminator. Instead of `scf.yield`, `scf.reduce` now terminates the body of `scf.parallel` ops. `scf.reduce` may contain an arbitrary number of reductions, with one region per reduction. Example: ```mlir %init = arith.constant 0.0 : f32 %r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init) -> f32, f32 { %elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32> %elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32> scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 }, { ^bb0(%lhs : f32, %rhs: f32): %res = arith.mulf %lhs, %rhs : f32 scf.reduce.return %res : f32 } } ``` `scf.reduce` operations can no longer be interleaved with other ops in the body of `scf.parallel`. This simplifies the op and makes it possible to assign the `RecursiveMemoryEffects` trait to `scf.reduce`. (This was not possible before because the op was not a terminator, causing the op to be DCE'd.)
1 parent ac029e0 commit 10056c8

File tree

26 files changed

+344
-340
lines changed

26 files changed

+344
-340
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def ParallelOp : SCF_Op<"parallel",
770770
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
771771
RecursiveMemoryEffects,
772772
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
773-
SingleBlockImplicitTerminator<"scf::YieldOp">]> {
773+
SingleBlockImplicitTerminator<"scf::ReduceOp">]> {
774774
let summary = "parallel for operation";
775775
let description = [{
776776
The "scf.parallel" operation represents a loop nest taking 4 groups of SSA
@@ -791,27 +791,36 @@ def ParallelOp : SCF_Op<"parallel",
791791

792792
The parallel loop operation supports reduction of values produced by
793793
individual iterations into a single result. This is modeled using the
794-
scf.reduce operation (see scf.reduce for details). Each result of a
795-
scf.parallel operation is associated with an initial value operand and
796-
reduce operation that is an immediate child. Reductions are matched to
797-
result and initial values in order of their appearance in the body.
798-
Consequently, we require that the body region has the same number of
799-
results and initial values as it has reduce operations.
800-
801-
The body region must contain exactly one block that terminates with
802-
"scf.yield" without operands. Parsing ParallelOp will create such a region
803-
and insert the terminator when it is absent from the custom format.
794+
"scf.reduce" terminator operation (see "scf.reduce" for details). The i-th
795+
result of an "scf.parallel" operation is associated with the i-th initial
796+
value operand, the i-th operand of the "scf.reduce" operation (the value to
797+
be reduced) and the i-th region of the "scf.reduce" operation (the reduction
798+
function). Consequently, we require that the number of results of an
799+
"scf.parallel" op matches the number of initial values and the the number of
800+
reductions in the "scf.reduce" terminator.
801+
802+
The body region must contain exactly one block that terminates with a
803+
"scf.reduce" operation. If an "scf.parallel" op has no reductions, the
804+
terminator has no operands and no regions. The "scf.parallel" parser will
805+
automatically insert the terminator for ops that have no reductions if it is
806+
absent.
804807

805808
Example:
806809

807810
```mlir
808811
%init = arith.constant 0.0 : f32
809-
scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 {
810-
%elem_to_reduce = load %buffer[%iv] : memref<100xf32>
811-
scf.reduce(%elem_to_reduce) : f32 {
812+
%r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init)
813+
-> f32, f32 {
814+
%elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32>
815+
%elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32>
816+
scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) {
812817
^bb0(%lhs : f32, %rhs: f32):
813818
%res = arith.addf %lhs, %rhs : f32
814819
scf.reduce.return %res : f32
820+
}, {
821+
^bb0(%lhs : f32, %rhs: f32):
822+
%res = arith.mulf %lhs, %rhs : f32
823+
scf.reduce.return %res : f32
815824
}
816825
}
817826
```
@@ -853,36 +862,36 @@ def ParallelOp : SCF_Op<"parallel",
853862
// ReduceOp
854863
//===----------------------------------------------------------------------===//
855864

856-
def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
857-
let summary = "reduce operation for parallel for";
865+
def ReduceOp : SCF_Op<"reduce", [
866+
Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
867+
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
868+
let summary = "reduce operation for scf.parallel";
858869
let description = [{
859-
"scf.reduce" is an operation occurring inside "scf.parallel" operations.
860-
It consists of one block with two arguments which have the same type as the
861-
operand of "scf.reduce".
862-
863-
"scf.reduce" is used to model the value for reduction computations of a
864-
"scf.parallel" operation. It has to appear as an immediate child of a
865-
"scf.parallel" and is associated with a result value of its parent
866-
operation.
867-
868-
Association is in the order of appearance in the body where the first
869-
result of a parallel loop operation corresponds to the first "scf.reduce"
870-
in the operation's body region. The reduce operation takes a single
871-
operand, which is the value to be used in the reduction.
872-
873-
The reduce operation contains a region whose entry block expects two
874-
arguments of the same type as the operand. As the iteration order of the
875-
parallel loop and hence reduction order is unspecified, the result of
876-
reduction may be non-deterministic unless the operation is associative and
877-
commutative.
878-
879-
The result of the reduce operation's body must have the same type as the
880-
operands and associated result value of the parallel loop operation.
870+
"scf.reduce" is the terminator for "scf.parallel" operations. It can model
871+
an arbitrary number of reductions. It has one region per reduction. Each
872+
region has one block with two arguments which have the same type as the
873+
corresponding operand of "scf.reduce". The operands of the op are the values
874+
that should be reduce; one value per reduction.
875+
876+
The i-th reduction (i.e., the i-th region and the i-th operand) corresponds
877+
the i-th initial value and the i-th result of the enclosing "scf.parallel"
878+
op.
879+
880+
The "scf.reduce" operation contains regions whose entry blocks expect two
881+
arguments of the same type as the corresponding operand. As the iteration
882+
order of the enclosing parallel loop and hence reduction order is
883+
unspecified, the results of the reductions may be non-deterministic unless
884+
the reductions are associative and commutative.
885+
886+
The result of a reduction region ("scf.reduce.return" operand) must have the
887+
same type as the corresponding "scf.reduce" operand and the corresponding
888+
"scf.parallel" initial value.
889+
881890
Example:
882891

883892
```mlir
884893
%operand = arith.constant 1.0 : f32
885-
scf.reduce(%operand) : f32 {
894+
scf.reduce(%operand : f32) {
886895
^bb0(%lhs : f32, %rhs: f32):
887896
%res = arith.addf %lhs, %rhs : f32
888897
scf.reduce.return %res : f32
@@ -892,14 +901,15 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
892901

893902
let skipDefaultBuilders = 1;
894903
let builders = [
895-
OpBuilder<(ins "Value":$operand,
896-
CArg<"function_ref<void (OpBuilder &, Location, Value, Value)>",
897-
"nullptr">:$bodyBuilderFn)>
904+
OpBuilder<(ins "ValueRange":$operands)>,
905+
OpBuilder<(ins)>
898906
];
899907

900-
let arguments = (ins AnyType:$operand);
901-
let hasCustomAssemblyFormat = 1;
902-
let regions = (region SizedRegion<1>:$reductionOperator);
908+
let arguments = (ins Variadic<AnyType>:$operands);
909+
let assemblyFormat = [{
910+
(`(` $operands^ `:` type($operands) `)`)? $reductions attr-dict
911+
}];
912+
let regions = (region VariadicRegion<SizedRegion<1>>:$reductions);
903913
let hasRegionVerifier = 1;
904914
}
905915

@@ -908,13 +918,14 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
908918
//===----------------------------------------------------------------------===//
909919

910920
def ReduceReturnOp :
911-
SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure,
912-
Terminator]> {
921+
SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, Terminator]> {
913922
let summary = "terminator for reduce operation";
914923
let description = [{
915924
"scf.reduce.return" is a special terminator operation for the block inside
916-
"scf.reduce". It terminates the region. It should have the same type as
917-
the operand of "scf.reduce". Example for the custom format:
925+
"scf.reduce" regions. It terminates the region. It should have the same
926+
operand type as the corresponding operand of the enclosing "scf.reduce" op.
927+
928+
Example:
918929

919930
```mlir
920931
scf.reduce.return %res : f32
@@ -1150,7 +1161,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
11501161

11511162
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
11521163
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
1153-
"ParallelOp", "WhileOp"]>]> {
1164+
"WhileOp"]>]> {
11541165
let summary = "loop yield and termination operation";
11551166
let description = [{
11561167
"scf.yield" yields an SSA value from the SCF dialect op region and

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,9 @@ class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
137137
LogicalResult matchAndRewrite(AffineYieldOp op,
138138
PatternRewriter &rewriter) const override {
139139
if (isa<scf::ParallelOp>(op->getParentOp())) {
140-
// scf.parallel does not yield any values via its terminator scf.yield but
141-
// models reductions differently using additional ops in its region.
142-
rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
143-
return success();
140+
// Terminator is rewritten as part of the "affine.parallel" lowering
141+
// pattern.
142+
return failure();
144143
}
145144
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
146145
return success();
@@ -203,7 +202,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
203202
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
204203

205204
// Get the terminator op.
206-
Operation *affineParOpTerminator = op.getBody()->getTerminator();
205+
auto affineParOpTerminator =
206+
cast<AffineYieldOp>(op.getBody()->getTerminator());
207207
scf::ParallelOp parOp;
208208
if (op.getResults().empty()) {
209209
// Case with no reduction operations/return values.
@@ -214,6 +214,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
214214
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
215215
parOp.getRegion().end());
216216
rewriter.replaceOp(op, parOp.getResults());
217+
rewriter.setInsertionPoint(affineParOpTerminator);
218+
rewriter.replaceOpWithNewOp<scf::ReduceOp>(affineParOpTerminator);
217219
return success();
218220
}
219221
// Case with affine.parallel with reduction operations/return values.
@@ -243,6 +245,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
243245
parOp.getRegion().end());
244246
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
245247
"Unequal number of reductions and operands.");
248+
249+
// Emit new "scf.reduce" terminator.
250+
rewriter.setInsertionPoint(affineParOpTerminator);
251+
auto reduceOp = rewriter.replaceOpWithNewOp<scf::ReduceOp>(
252+
affineParOpTerminator, affineParOpTerminator->getOperands());
246253
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
247254
// For each of the reduction operations get the respective mlir::Value.
248255
std::optional<arith::AtomicRMWKind> reductionOp =
@@ -251,13 +258,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
251258
assert(reductionOp && "Reduction Operation cannot be of None Type");
252259
arith::AtomicRMWKind reductionOpValue = *reductionOp;
253260
rewriter.setInsertionPoint(&parOp.getBody()->back());
254-
auto reduceOp = rewriter.create<scf::ReduceOp>(
255-
loc, affineParOpTerminator->getOperand(i));
256-
rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
261+
Block &reductionBody = reduceOp.getReductions()[i].front();
262+
rewriter.setInsertionPointToEnd(&reductionBody);
257263
Value reductionResult = arith::getReductionOp(
258-
reductionOpValue, rewriter, loc,
259-
reduceOp.getReductionOperator().front().getArgument(0),
260-
reduceOp.getReductionOperator().front().getArgument(1));
264+
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
265+
reductionBody.getArgument(1));
261266
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
262267
}
263268
rewriter.replaceOp(op, parOp.getResults());

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ LogicalResult
471471
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
472472
PatternRewriter &rewriter) const {
473473
Location loc = parallelOp.getLoc();
474+
auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator());
474475

475476
// For a parallel loop, we essentially need to create an n-dimensional loop
476477
// nest. We do this by translating to scf.for ops and have those lowered in
@@ -506,23 +507,20 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
506507
}
507508

508509
// First, merge reduction blocks into the main region.
509-
SmallVector<Value, 4> yieldOperands;
510+
SmallVector<Value> yieldOperands;
510511
yieldOperands.reserve(parallelOp.getNumResults());
511-
for (auto &op : *parallelOp.getBody()) {
512-
auto reduce = dyn_cast<ReduceOp>(op);
513-
if (!reduce)
514-
continue;
515-
516-
Block &reduceBlock = reduce.getReductionOperator().front();
512+
for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
513+
Block &reductionBody = reductionOp.getReductions()[i].front();
517514
Value arg = iterArgs[yieldOperands.size()];
518-
yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
519-
rewriter.eraseOp(reduceBlock.getTerminator());
520-
rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
521-
rewriter.eraseOp(reduce);
515+
yieldOperands.push_back(
516+
cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
517+
rewriter.eraseOp(reductionBody.getTerminator());
518+
rewriter.inlineBlockBefore(&reductionBody, reductionOp,
519+
{arg, reductionOp.getOperands()[i]});
522520
}
521+
rewriter.eraseOp(reductionOp);
523522

524523
// Then merge the loop body without the terminator.
525-
rewriter.eraseOp(parallelOp.getBody()->getTerminator());
526524
Block *newBody = rewriter.getInsertionBlock();
527525
if (newBody->empty())
528526
rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
@@ -711,7 +709,7 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
711709
parallelOp.getRegion().begin());
712710
// Replace the terminator.
713711
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
714-
rewriter.replaceOpWithNewOp<scf::YieldOp>(
712+
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
715713
parallelOp.getRegion().front().getTerminator());
716714

717715
// Erase the scf.forall op.

0 commit comments

Comments
 (0)