Commit aff6bf4
[mlir] support conversion of parallel reduction loops to std
The recently introduced support for converting sequential reduction loops to a CFG of basic blocks in the Standard dialect makes it possible to perform a staged conversion of parallel reduction loops into a similar CFG, using sequential loops as an intermediate step. This is already the case for parallel loops without reductions, so extend the pattern to support the additional use case.

Differential Revision: https://reviews.llvm.org/D75599
1 parent 16c6e0f commit aff6bf4
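
To illustrate the staged approach on the first new test case: the parallel reduction loop below is first rewritten into a sequential loop.for that carries the partial reduction value as an iteration argument, and that sequential loop is then lowered to a CFG by the existing for-to-std pattern. The intermediate form is a hand-written sketch (value names are illustrative and the iter_args assembly of loop.for is assumed), not output of the pass:

// Input: a parallel loop with a single multiplication reduction.
%0 = loop.parallel (%i) = (%lb) to (%ub) step (%step) init(%init) {
  %cst = constant 42.0 : f32
  loop.reduce(%cst) {
  ^bb0(%lhs: f32, %rhs: f32):
    %1 = mulf %lhs, %rhs : f32
    loop.reduce.return %1 : f32
  } : f32
} : f32

// Intermediate stage (sketch): a sequential loop carrying the partial product.
%0 = loop.for %i = %lb to %ub step %step iter_args(%acc = %init) -> (f32) {
  %cst = constant 42.0 : f32
  %prod = mulf %acc, %cst : f32
  loop.yield %prod : f32
}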

File tree: 4 files changed (+149, -12 lines)

  mlir/include/mlir/Dialect/LoopOps/LoopOps.td
  mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
  mlir/lib/Dialect/LoopOps/LoopOps.cpp
  mlir/test/Conversion/convert-to-cfg.mlir

mlir/include/mlir/Dialect/LoopOps/LoopOps.td

Lines changed: 2 additions & 1 deletion
@@ -131,7 +131,8 @@ def ForOp : Loop_Op<"for",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
-              "Value lowerBound, Value upperBound, Value step">
+              "Value lowerBound, Value upperBound, Value step, "
+              "ValueRange iterArgs = llvm::None">
   ];
 
   let extraClassDeclaration = [{

mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp

Lines changed: 56 additions & 10 deletions
@@ -274,29 +274,75 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   Location loc = parallelOp.getLoc();
   BlockAndValueMapping mapping;
 
-  if (parallelOp.getNumResults() != 0) {
-    // TODO: Implement lowering of parallelOp with reductions.
-    return matchFailure();
-  }
-
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to loop.for ops and have those lowered in
-  // a further rewrite.
+  // a further rewrite. If a parallel loop contains reductions (and thus returns
+  // values), forward the initial values for the reductions down the loop
+  // hierarchy and bubble up the results by modifying the "yield" terminator.
+  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
+  bool first = true;
+  SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                  parallelOp.upperBound(), parallelOp.step())) {
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
-    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
     mapping.map(iv, forOp.getInductionVar());
+    auto iterRange = forOp.getRegionIterArgs();
+    iterArgs.assign(iterRange.begin(), iterRange.end());
+
+    if (first) {
+      // Store the results of the outermost loop that will be used to replace
+      // the results of the parallel loop when it is fully rewritten.
+      loopResults.assign(forOp.result_begin(), forOp.result_end());
+      first = false;
+    } else {
+      // A loop is constructed with an empty "yield" terminator by default.
+      // Replace it with another "yield" that forwards the results of the nested
+      // loop to the parent loop. We need to explicitly make sure the new
+      // terminator is the last operation in the block because further transforms
+      // rely on this.
+      rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+      rewriter.replaceOpWithNewOp<YieldOp>(
+          rewriter.getInsertionBlock()->getTerminator(), forOp.getResults());
+    }
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
   // Now copy over the contents of the body.
-  for (auto &op : parallelOp.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  SmallVector<Value, 4> yieldOperands;
+  yieldOperands.reserve(parallelOp.getNumResults());
+  for (auto &op : parallelOp.getBody()->without_terminator()) {
+    // Reduction blocks are handled differently.
+    auto reduce = dyn_cast<ReduceOp>(op);
+    if (!reduce) {
+      rewriter.clone(op, mapping);
+      continue;
+    }
+
+    // Clone the body of the reduction operation into the body of the loop,
+    // using operands of "loop.reduce" and iteration arguments corresponding
+    // to the reduction value to replace arguments of the reduction block.
+    // Collect operands of "loop.reduce.return" to be returned by a final
+    // "loop.yield" instead.
+    Value arg = iterArgs[yieldOperands.size()];
+    Block &reduceBlock = reduce.reductionOperator().front();
+    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
+    mapping.map(reduceBlock.getArgument(1),
+                mapping.lookupOrDefault(reduce.operand()));
+    for (auto &nested : reduceBlock.without_terminator())
+      rewriter.clone(nested, mapping);
+    yieldOperands.push_back(
+        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+  }
+
+  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+  rewriter.replaceOpWithNewOp<YieldOp>(
+      rewriter.getInsertionBlock()->getTerminator(), yieldOperands);
 
-  rewriter.eraseOp(parallelOp);
+  rewriter.replaceOp(parallelOp, loopResults);
 
   return matchSuccess();
 }
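
For the two-dimensional, two-reduction test case added below, this pattern produces a loop nest in which the init values are forwarded inwards as iteration arguments, the bodies of the loop.reduce ops are inlined into the innermost body, and each inner loop's results are yielded to its parent. Schematically (a hand-written sketch of the intermediate IR before the sequential loops are themselves lowered to a CFG; bound and init value names are illustrative):

%res:2 = loop.for %i0 = %lb0 to %ub0 step %s0
    iter_args(%acc0 = %init0, %acc1 = %init1) -> (f32, i64) {
  %inner:2 = loop.for %i1 = %lb1 to %ub1 step %s1
      iter_args(%a0 = %acc0, %a1 = %acc1) -> (f32, i64) {
    %cf = constant 42.0 : f32
    %r0 = addf %a0, %cf : f32         // inlined body of the first loop.reduce
    %v = call @generate() : () -> i64
    %r1 = or %a1, %v : i64            // inlined body of the second loop.reduce
    loop.yield %r0, %r1 : f32, i64    // operands taken from loop.reduce.return
  }
  loop.yield %inner#0, %inner#1 : f32, i64  // bubble the inner results up to the parent
}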

mlir/lib/Dialect/LoopOps/LoopOps.cpp

Lines changed: 6 additions & 1 deletion
@@ -61,11 +61,16 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
 //===----------------------------------------------------------------------===//
 
 void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
-                  Value step) {
+                  Value step, ValueRange iterArgs) {
   result.addOperands({lb, ub, step});
+  result.addOperands(iterArgs);
+  for (Value v : iterArgs)
+    result.addTypes(v.getType());
   Region *bodyRegion = result.addRegion();
   ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
   bodyRegion->front().addArgument(builder->getIndexType());
+  for (Value v : iterArgs)
+    bodyRegion->front().addArgument(v.getType());
 }
 
 static LogicalResult verify(ForOp op) {
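
With the extended builder, each value in iterArgs adds one operand, one result of the matching type, and one entry-block argument next to the index-typed induction variable. A ForOp built with a single f32 init value therefore has the following shape (a sketch in the iter_args assembly form; %elem stands for some value computed in the body):

%sum = loop.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (f32) {
  // Entry block arguments: %iv : index, %acc : f32.
  %next = addf %acc, %elem : f32
  loop.yield %next : f32
}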

mlir/test/Conversion/convert-to-cfg.mlir

Lines changed: 85 additions & 0 deletions
@@ -236,3 +236,88 @@ func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
   }
   return %r : f32
 }
+
+func @generate() -> i64
+
+// CHECK-LABEL: @simple_parallel_reduce_loop
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
+func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
+                                  %arg2: index, %arg3: f32) -> f32 {
+  // A parallel loop with reduction is converted through sequential loops with
+  // reductions into a CFG of blocks where the partially reduced value is
+  // passed across as a block argument.
+
+  // Branch to the condition block passing in the initial reduction value.
+  // CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
+
+  // Condition branch takes as arguments the current value of the iteration
+  // variable and the current partially reduced value.
+  // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
+  // CHECK: %[[COMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]]
+  // CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+
+  // Bodies of loop.reduce operations are folded into the main loop body. The
+  // result of this partial reduction is passed as argument to the condition
+  // block.
+  // CHECK: ^[[BODY]]:
+  // CHECK: %[[CST:.*]] = constant 4.2
+  // CHECK: %[[PROD:.*]] = mulf %[[ITER_ARG]], %[[CST]]
+  // CHECK: %[[INCR:.*]] = addi %[[ITER]], %[[STEP]]
+  // CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
+
+  // The continuation block has access to the (last value of) reduction.
+  // CHECK: ^[[CONTINUE]]:
+  // CHECK: return %[[ITER_ARG]]
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+    %cst = constant 42.0 : f32
+    loop.reduce(%cst) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = mulf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+  } : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: parallel_reduce_loop
+// CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
+func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                           %arg3 : index, %arg4 : index, %arg5 : f32) -> (f32, i64) {
+  // Multiple reduction blocks should be folded in the same body, and the
+  // reduction value must be forwarded through block structures.
+  // CHECK: %[[INIT2:.*]] = constant 42
+  // CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
+  // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // CHECK: ^[[BODY_OUT]]:
+  // CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // CHECK: ^[[BODY_IN]]:
+  // CHECK: %[[REDUCE1:.*]] = addf %[[ITER_ARG1_IN]], %{{.*}}
+  // CHECK: %[[REDUCE2:.*]] = or %[[ITER_ARG2_IN]], %{{.*}}
+  // CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
+  // CHECK: ^[[CONT_IN]]:
+  // CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
+  // CHECK: ^[[CONT_OUT]]:
+  // CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  %step = constant 1 : index
+  %init = constant 42 : i64
+  %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                       step (%arg4, %step) init(%arg5, %init) {
+    %cf = constant 42.0 : f32
+    loop.reduce(%cf) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = addf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+
+    %2 = call @generate() : () -> i64
+    loop.reduce(%2) {
+    ^bb0(%lhs: i64, %rhs: i64):
+      %3 = or %lhs, %rhs : i64
+      loop.reduce.return %3 : i64
+    } : i64
+  } : f32, i64
+  return %0#0, %0#1 : f32, i64
+}
