Skip to content

Commit 87568ff

Browse files
[mlir][SCF] convert-scf-to-cf: Lower scf.forall to scf.parallel (#65449)
scf.forall ops without shared outputs (i.e., fully bufferized ops) are lowered to scf.parallel. scf.forall ops are typically lowered by an earlier pass depending on the execution target. E.g., there are optimized lowerings for GPU execution. This new lowering is for completeness (convert-scf-to-cf can now lower all SCF loop constructs) and provides a simple CPU lowering strategy for testing purposes. scf.parallel is currently lowered to scf.for, which executes sequentially. The scf.parallel lowering could be improved in the future to run on multiple threads.
1 parent 3e7cd5e commit 87568ff

File tree

2 files changed

+74
-3
lines changed

2 files changed

+74
-3
lines changed

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,18 @@ struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
298298
LogicalResult matchAndRewrite(IndexSwitchOp op,
299299
PatternRewriter &rewriter) const override;
300300
};
301+
302+
/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
/// has no shared outputs. Ops with shared outputs should be bufferized first.
/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
/// dialects/passes.
///
/// The rewrite itself is defined out-of-line below, next to the other
/// pattern implementations in this file.
struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
  using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const override;
};
312+
301313
} // namespace
302314

303315
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -677,10 +689,41 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
677689
return success();
678690
}
679691

692+
/// Rewrite an scf.forall without shared outputs into an scf.parallel with the
/// same bounds/steps and the same body, swapping the scf.forall.in_parallel
/// terminator for the scf.yield that scf.parallel requires.
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                              PatternRewriter &rewriter) const {
  Location loc = forallOp.getLoc();
  // Shared outputs carry tensor semantics that scf.parallel cannot express;
  // such ops must be bufferized (outputs removed) before this pattern applies.
  if (!forallOp.getOutputs().empty())
    return rewriter.notifyMatchFailure(
        forallOp,
        "only fully bufferized scf.forall ops can be lowered to scf.parallel");

  // Convert mixed bounds and steps to SSA values. (Mixed = each entry is
  // either a static attribute or a dynamic Value; materialize constants for
  // the static ones.)
  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedLowerBound());
  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedUpperBound());
  SmallVector<Value> steps =
      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());

  // Create empty scf.parallel op. Its auto-created entry block is erased and
  // then the scf.forall body (with matching induction variables) is moved in.
  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
  rewriter.eraseBlock(&parallelOp.getRegion().front());
  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
                              parallelOp.getRegion().begin());
  // Replace the terminator: the inlined block still ends with
  // scf.forall.in_parallel (empty, since there are no shared outputs); swap it
  // for the scf.yield that scf.parallel expects.
  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
  rewriter.replaceOpWithNewOp<scf::YieldOp>(
      parallelOp.getRegion().front().getTerminator());

  // Erase the scf.forall op. Both ops have zero results here, so replaceOp
  // simply removes the original.
  rewriter.replaceOp(forallOp, parallelOp);
  return success();
}
722+
680723
void mlir::populateSCFToControlFlowConversionPatterns(
681724
RewritePatternSet &patterns) {
682-
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
683-
ExecuteRegionLowering, IndexSwitchLowering>(
725+
patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
726+
WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
684727
patterns.getContext());
685728
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
686729
}
@@ -691,7 +734,7 @@ void SCFToControlFlowPass::runOnOperation() {
691734

692735
// Configure conversion to lower out SCF operations.
693736
ConversionTarget target(getContext());
694-
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
737+
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
695738
scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
696739
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
697740
if (failed(

mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,3 +648,31 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
648648
// CHECK-NEXT: return %[[V]]
649649
return %0 : i32
650650
}
651+
652+
// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
// scf.for and then to unstructured control flow. scf.parallel could lower more
// efficiently to multi-threaded IR, at which point scf.forall would
// automatically lower to multi-threaded IR.

// CHECK-LABEL: func @forall(
// CHECK-SAME: %[[num_threads:.*]]: index)
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: cf.br ^[[bb1:.*]](%[[c0]] : index)
// CHECK: ^[[bb1]](%[[arg0:.*]]: index):
// CHECK: %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
// CHECK: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
// CHECK: ^[[bb2]]:
// CHECK: "test.foo"(%[[arg0]])
// CHECK: %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]
// CHECK: cf.br ^[[bb1]](%[[addi]] : index)
// CHECK: ^[[bb3]]:
// CHECK: return
func.func @forall(%num_threads: index) {
  scf.forall (%thread_idx) in (%num_threads) {
    "test.foo"(%thread_idx) : (index) -> ()
    // Empty terminator: there are no shared outputs to write back, which is
    // exactly the precondition for the scf.forall -> scf.parallel lowering.
    scf.forall.in_parallel {
    }
  }
  return
}

0 commit comments

Comments
 (0)