[mlir][SCF] scf.parallel: Make reductions part of the terminator #75314

Merged
111 changes: 61 additions & 50 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -770,7 +770,7 @@ def ParallelOp : SCF_Op<"parallel",
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"scf::YieldOp">]> {
SingleBlockImplicitTerminator<"scf::ReduceOp">]> {
let summary = "parallel for operation";
let description = [{
The "scf.parallel" operation represents a loop nest taking 4 groups of SSA
@@ -791,27 +791,36 @@ def ParallelOp : SCF_Op<"parallel",

The parallel loop operation supports reduction of values produced by
individual iterations into a single result. This is modeled using the
scf.reduce operation (see scf.reduce for details). Each result of a
scf.parallel operation is associated with an initial value operand and
reduce operation that is an immediate child. Reductions are matched to
result and initial values in order of their appearance in the body.
Consequently, we require that the body region has the same number of
results and initial values as it has reduce operations.

The body region must contain exactly one block that terminates with
"scf.yield" without operands. Parsing ParallelOp will create such a region
and insert the terminator when it is absent from the custom format.
"scf.reduce" terminator operation (see "scf.reduce" for details). The i-th
result of an "scf.parallel" operation is associated with the i-th initial
value operand, the i-th operand of the "scf.reduce" operation (the value to
be reduced) and the i-th region of the "scf.reduce" operation (the reduction
function). Consequently, we require that the number of results of an
"scf.parallel" op matches the number of initial values and the number of
reductions in the "scf.reduce" terminator.

The body region must contain exactly one block that terminates with a
"scf.reduce" operation. If an "scf.parallel" op has no reductions, the
terminator has no operands and no regions. The "scf.parallel" parser will
automatically insert the terminator for ops that have no reductions if it is
absent.

Example:

```mlir
%init = arith.constant 0.0 : f32
scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 {
%elem_to_reduce = load %buffer[%iv] : memref<100xf32>
scf.reduce(%elem_to_reduce) : f32 {
%r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init)
-> f32, f32 {
%elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32>
%elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32>
scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
}, {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.mulf %lhs, %rhs : f32
scf.reduce.return %res : f32
}
}
```
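
Reviewer sketch of the pairing rule described above: a minimal plain C++ model, assuming a sequential execution order and using hypothetical names (`Combiner`, `runParallelModel`) that are not part of MLIR. It only illustrates that result i is produced from initial value i, terminator operand i, and reduction region i.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical model, not MLIR API: one Combiner stands in for one region of
// the "scf.reduce" terminator.
using Combiner = std::function<float(float, float)>;

// Models a 1-D "scf.parallel" loop with reductions, executed sequentially.
// `body(iv)` returns the values that would be the "scf.reduce" operands.
std::vector<float> runParallelModel(
    std::size_t lb, std::size_t ub, std::size_t step,
    const std::vector<float> &inits,
    const std::function<std::vector<float>(std::size_t)> &body,
    const std::vector<Combiner> &reductions) {
  // Mirrors the documented requirement: #results == #inits == #reductions.
  assert(inits.size() == reductions.size());
  assert(step > 0);
  std::vector<float> results = inits;
  for (std::size_t iv = lb; iv < ub; iv += step) {
    std::vector<float> operands = body(iv);
    assert(operands.size() == reductions.size());
    // The i-th operand is folded into the i-th result by the i-th combiner.
    for (std::size_t i = 0; i < reductions.size(); ++i)
      results[i] = reductions[i](results[i], operands[i]);
  }
  return results;
}
```

Called with an add combiner and a mul combiner, this mirrors the two-reduction example in the description. Note that a multiplicative reduction would normally want an initial value of 1 rather than the shared `%init` of 0.0 used in the MLIR example.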
@@ -853,36 +862,36 @@ def ParallelOp : SCF_Op<"parallel",
// ReduceOp
//===----------------------------------------------------------------------===//

def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
let summary = "reduce operation for parallel for";
def ReduceOp : SCF_Op<"reduce", [
Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "reduce operation for scf.parallel";
let description = [{
"scf.reduce" is an operation occurring inside "scf.parallel" operations.
It consists of one block with two arguments which have the same type as the
operand of "scf.reduce".

"scf.reduce" is used to model the value for reduction computations of a
"scf.parallel" operation. It has to appear as an immediate child of a
"scf.parallel" and is associated with a result value of its parent
operation.

Association is in the order of appearance in the body where the first
result of a parallel loop operation corresponds to the first "scf.reduce"
in the operation's body region. The reduce operation takes a single
operand, which is the value to be used in the reduction.

The reduce operation contains a region whose entry block expects two
arguments of the same type as the operand. As the iteration order of the
parallel loop and hence reduction order is unspecified, the result of
reduction may be non-deterministic unless the operation is associative and
commutative.

The result of the reduce operation's body must have the same type as the
operands and associated result value of the parallel loop operation.
"scf.reduce" is the terminator for "scf.parallel" operations. It can model
an arbitrary number of reductions. It has one region per reduction. Each
region has one block with two arguments which have the same type as the
corresponding operand of "scf.reduce". The operands of the op are the values
that should be reduced; one value per reduction.

The i-th reduction (i.e., the i-th region and the i-th operand) corresponds
to the i-th initial value and the i-th result of the enclosing
"scf.parallel" op.

The "scf.reduce" operation contains regions whose entry blocks expect two
arguments of the same type as the corresponding operand. As the iteration
order of the enclosing parallel loop and hence reduction order is
unspecified, the results of the reductions may be non-deterministic unless
the reductions are associative and commutative.

The result of a reduction region ("scf.reduce.return" operand) must have the
same type as the corresponding "scf.reduce" operand and the corresponding
"scf.parallel" initial value.

Example:

```mlir
%operand = arith.constant 1.0 : f32
scf.reduce(%operand) : f32 {
scf.reduce(%operand : f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
@@ -892,14 +901,15 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$operand,
CArg<"function_ref<void (OpBuilder &, Location, Value, Value)>",
"nullptr">:$bodyBuilderFn)>
OpBuilder<(ins "ValueRange":$operands)>,
OpBuilder<(ins)>
];

let arguments = (ins AnyType:$operand);
let hasCustomAssemblyFormat = 1;
let regions = (region SizedRegion<1>:$reductionOperator);
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = [{
(`(` $operands^ `:` type($operands) `)`)? $reductions attr-dict
}];
let regions = (region VariadicRegion<SizedRegion<1>>:$reductions);
let hasRegionVerifier = 1;
}
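
Reviewer sketch of the non-determinism caveat in the ReduceOp description: a minimal plain C++ model, assuming that only the visiting order of iterations varies. The helper name `reduceInOrder` is hypothetical and not part of MLIR. It folds the same values in a caller-chosen order, so an order-sensitive combiner can be compared against an associative and commutative one.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper, not MLIR API: folds `values` into `init`, visiting
// them in the order given by `order` (a permutation of indices).
long reduceInOrder(long init, const std::vector<long> &values,
                   const std::vector<std::size_t> &order,
                   const std::function<long(long, long)> &combine) {
  assert(order.size() == values.size());
  long acc = init;
  for (std::size_t idx : order) {
    assert(idx < values.size());
    // lhs is the running value, rhs is the value from this iteration.
    acc = combine(acc, values[idx]);
  }
  return acc;
}
```

With integer addition, forward and reverse orders give the same result. With a combiner such as `2 * lhs + rhs`, which is neither associative nor commutative, the two orders generally differ, which is the situation the description warns about.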

@@ -908,13 +918,14 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
//===----------------------------------------------------------------------===//

def ReduceReturnOp :
SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure,
Terminator]> {
SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, Terminator]> {
let summary = "terminator for reduce operation";
let description = [{
"scf.reduce.return" is a special terminator operation for the block inside
"scf.reduce". It terminates the region. It should have the same type as
the operand of "scf.reduce". Example for the custom format:
"scf.reduce" regions. It terminates the region. It should have the same
operand type as the corresponding operand of the enclosing "scf.reduce" op.

Example:

```mlir
scf.reduce.return %res : f32
@@ -1150,7 +1161,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,

def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
"ParallelOp", "WhileOp"]>]> {
"WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
27 changes: 16 additions & 11 deletions mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -137,10 +137,9 @@ class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
LogicalResult matchAndRewrite(AffineYieldOp op,
PatternRewriter &rewriter) const override {
if (isa<scf::ParallelOp>(op->getParentOp())) {
// scf.parallel does not yield any values via its terminator scf.yield but
// models reductions differently using additional ops in its region.
rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
return success();
// Terminator is rewritten as part of the "affine.parallel" lowering
// pattern.
return failure();
}
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
return success();
@@ -203,7 +202,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));

// Get the terminator op.
Operation *affineParOpTerminator = op.getBody()->getTerminator();
auto affineParOpTerminator =
cast<AffineYieldOp>(op.getBody()->getTerminator());
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
@@ -214,6 +214,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
parOp.getRegion().end());
rewriter.replaceOp(op, parOp.getResults());
rewriter.setInsertionPoint(affineParOpTerminator);
rewriter.replaceOpWithNewOp<scf::ReduceOp>(affineParOpTerminator);
return success();
}
// Case with affine.parallel with reduction operations/return values.
@@ -243,6 +245,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
parOp.getRegion().end());
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
"Unequal number of reductions and operands.");

// Emit new "scf.reduce" terminator.
rewriter.setInsertionPoint(affineParOpTerminator);
auto reduceOp = rewriter.replaceOpWithNewOp<scf::ReduceOp>(
affineParOpTerminator, affineParOpTerminator->getOperands());
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
// For each of the reduction operations get the respective mlir::Value.
std::optional<arith::AtomicRMWKind> reductionOp =
@@ -251,13 +258,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
assert(reductionOp && "Reduction Operation cannot be of None Type");
arith::AtomicRMWKind reductionOpValue = *reductionOp;
rewriter.setInsertionPoint(&parOp.getBody()->back());
auto reduceOp = rewriter.create<scf::ReduceOp>(
loc, affineParOpTerminator->getOperand(i));
rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
Block &reductionBody = reduceOp.getReductions()[i].front();
rewriter.setInsertionPointToEnd(&reductionBody);
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc,
reduceOp.getReductionOperator().front().getArgument(0),
reduceOp.getReductionOperator().front().getArgument(1));
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
reductionBody.getArgument(1));
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
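
Reviewer sketch of the per-reduction body construction in this lowering: a minimal plain C++ model, assuming a small hypothetical `ReductionKind` enum in place of the real `arith::AtomicRMWKind` and ordinary closures in place of the regions filled via `arith::getReductionOp`. It only illustrates that one combiner is built per reduction, in order, so that combiner i lines up with terminator operand i.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a reduction kind; NOT the real AtomicRMWKind.
enum class ReductionKind { Add, Mul, Max, Min };

using Combiner = std::function<double(double, double)>;

// Models "fill region i with the op for kind i" as "return a closure".
Combiner makeCombiner(ReductionKind kind) {
  switch (kind) {
  case ReductionKind::Add:
    return [](double lhs, double rhs) { return lhs + rhs; };
  case ReductionKind::Mul:
    return [](double lhs, double rhs) { return lhs * rhs; };
  case ReductionKind::Max:
    return [](double lhs, double rhs) { return std::max(lhs, rhs); };
  case ReductionKind::Min:
    return [](double lhs, double rhs) { return std::min(lhs, rhs); };
  }
  throw std::invalid_argument("unknown reduction kind");
}

// Builds one combiner per reduction, preserving order.
std::vector<Combiner> buildReductions(const std::vector<ReductionKind> &kinds) {
  std::vector<Combiner> result;
  result.reserve(kinds.size());
  for (ReductionKind kind : kinds)
    result.push_back(makeCombiner(kind));
  assert(result.size() == kinds.size());
  return result;
}
```

The point of the model is the indexing: the terminator is created once with all operands, and the loop then populates region i for reduction i instead of creating a separate reduce op per reduction.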
24 changes: 11 additions & 13 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -471,6 +471,7 @@ LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator());

// For a parallel loop, we essentially need to create an n-dimensional loop
// nest. We do this by translating to scf.for ops and have those lowered in
@@ -506,23 +507,20 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
}

// First, merge reduction blocks into the main region.
SmallVector<Value, 4> yieldOperands;
SmallVector<Value> yieldOperands;
yieldOperands.reserve(parallelOp.getNumResults());
for (auto &op : *parallelOp.getBody()) {
auto reduce = dyn_cast<ReduceOp>(op);
if (!reduce)
continue;

Block &reduceBlock = reduce.getReductionOperator().front();
for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
Block &reductionBody = reductionOp.getReductions()[i].front();
Value arg = iterArgs[yieldOperands.size()];
yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
rewriter.eraseOp(reduceBlock.getTerminator());
rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
rewriter.eraseOp(reduce);
yieldOperands.push_back(
cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
rewriter.eraseOp(reductionBody.getTerminator());
rewriter.inlineBlockBefore(&reductionBody, reductionOp,
{arg, reductionOp.getOperands()[i]});
}
rewriter.eraseOp(reductionOp);

// Then merge the loop body without the terminator.
rewriter.eraseOp(parallelOp.getBody()->getTerminator());
Block *newBody = rewriter.getInsertionBlock();
if (newBody->empty())
rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
@@ -711,7 +709,7 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
parallelOp.getRegion().begin());
// Replace the terminator.
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
rewriter.replaceOpWithNewOp<scf::YieldOp>(
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());

// Erase the scf.forall op.
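
Reviewer sketch of what the `ParallelLowering` hunk above produces conceptually: a minimal plain C++ model, assuming the n-dimensional parallel loop becomes a nest of sequential loops that thread the reduction values through as loop-carried "iter args". All names (`Body`, `Combiner`, `lowerToSequentialNest`) are hypothetical and not part of MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Combiner = std::function<long(long, long)>;
// The body receives the induction variables and returns the values that
// would be the operands of the "scf.reduce" terminator.
using Body = std::function<std::vector<long>(const std::vector<long> &)>;

// Hypothetical model of the lowered loop nest, not MLIR API.
std::vector<long> lowerToSequentialNest(
    const std::vector<long> &lbs, const std::vector<long> &ubs,
    const std::vector<long> &steps, const std::vector<long> &inits,
    const Body &body, const std::vector<Combiner> &reductions) {
  assert(lbs.size() == ubs.size() && lbs.size() == steps.size());
  assert(inits.size() == reductions.size());
  for (long step : steps)
    assert(step > 0);

  std::vector<long> iterArgs = inits; // loop-carried reduction values
  std::vector<long> ivs = lbs;        // current induction variables

  // Recursively walk dimension `dim`; the innermost level runs the body and
  // then the inlined reduction bodies, one per reduction.
  std::function<void(std::size_t)> visit = [&](std::size_t dim) {
    if (dim == lbs.size()) {
      std::vector<long> operands = body(ivs);
      assert(operands.size() == reductions.size());
      for (std::size_t i = 0; i < reductions.size(); ++i)
        iterArgs[i] = reductions[i](iterArgs[i], operands[i]);
      return;
    }
    for (long iv = lbs[dim]; iv < ubs[dim]; iv += steps[dim]) {
      ivs[dim] = iv;
      visit(dim + 1);
    }
  };
  visit(0);
  return iterArgs;
}
```

As in the patch, the carried value is passed as the first block argument and the terminator operand as the second, which corresponds to `{arg, reductionOp.getOperands()[i]}` in the inlining call.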