Skip to content

Commit 77533d7

Browse files
Moerafaatd0k
authored andcommitted
[mlir][SCF] Adding custom builder to SCF::WhileOp.
This is a similar builder to the one for SCF::IfOp which allows users to pass region builders to it. Refer to the builders for IfOp. Reviewed By: tpopp Differential Revision: https://reviews.llvm.org/D137709
1 parent beaffb0 commit 77533d7

File tree

3 files changed

+64
-35
lines changed

3 files changed

+64
-35
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def WhileOp : SCF_Op<"while",
935935

936936
Note that the types of region arguments need not to match with each other.
937937
The op expects the operand types to match with argument types of the
938-
"before" region"; the result types to match with the trailing operand types
938+
"before" region; the result types to match with the trailing operand types
939939
of the terminator of the "before" region, and with the argument types of the
940940
"after" region. The following scheme can be used to share the results of
941941
some operations executed in the "before" region with the "after" region,
@@ -983,7 +983,16 @@ def WhileOp : SCF_Op<"while",
983983
let results = (outs Variadic<AnyType>:$results);
984984
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
985985

986+
let builders = [
987+
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
988+
"function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
989+
"function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
990+
];
991+
986992
let extraClassDeclaration = [{
993+
using BodyBuilderFn =
994+
function_ref<void(OpBuilder &, Location, ValueRange)>;
995+
987996
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
988997
ConditionOp getConditionOp();
989998
YieldOp getYieldOp();

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -71,40 +71,32 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
7171
SmallVector<Type> types = {elementTy, elementTy, elementTy};
7272
SmallVector<Location> locations = {loc, loc, loc};
7373

74-
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
75-
Block *before =
76-
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
77-
Block *after =
78-
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
79-
80-
// The conditional block of the while loop.
81-
{
82-
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
83-
Value input = before->getArgument(0);
84-
Value zero = before->getArgument(2);
85-
86-
Value inputNotZero = rewriter.create<arith::CmpIOp>(
87-
loc, arith::CmpIPredicate::ne, input, zero);
88-
rewriter.create<scf::ConditionOp>(loc, inputNotZero,
89-
before->getArguments());
90-
}
91-
92-
// The body of the while loop: shift right until reaching a value of 0.
93-
{
94-
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
95-
Value input = after->getArgument(0);
96-
Value leadingZeros = after->getArgument(1);
97-
98-
auto one =
99-
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
100-
auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
101-
auto leadingZerosMinusOne =
102-
rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
103-
104-
rewriter.create<scf::YieldOp>(
105-
loc,
106-
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
107-
}
74+
auto whileOp = rewriter.create<scf::WhileOp>(
75+
loc, types, operands,
76+
[&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
77+
// The conditional block of the while loop.
78+
Value input = args[0];
79+
Value zero = args[2];
80+
81+
Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
82+
loc, arith::CmpIPredicate::ne, input, zero);
83+
beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
84+
},
85+
[&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
86+
// The body of the while loop: shift right until reaching a value of 0.
87+
Value input = args[0];
88+
Value leadingZeros = args[1];
89+
90+
auto one = afterBuilder.create<arith::ConstantOp>(
91+
loc, IntegerAttr::get(elementTy, 1));
92+
auto shifted =
93+
afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
94+
auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
95+
loc, resultTy, leadingZeros, one);
96+
97+
afterBuilder.create<scf::YieldOp>(
98+
loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
99+
});
108100

109101
rewriter.setInsertionPointAfter(whileOp);
110102
rewriter.replaceOp(op, whileOp->getResult(1));

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2669,6 +2669,34 @@ LogicalResult ReduceReturnOp::verify() {
26692669
// WhileOp
26702670
//===----------------------------------------------------------------------===//
26712671

2672+
void WhileOp::build(::mlir::OpBuilder &odsBuilder,
2673+
::mlir::OperationState &odsState, TypeRange resultTypes,
2674+
ValueRange operands, BodyBuilderFn beforeBuilder,
2675+
BodyBuilderFn afterBuilder) {
2676+
assert(beforeBuilder && "the builder callback for 'before' must be present");
2677+
assert(afterBuilder && "the builder callback for 'after' must be present");
2678+
2679+
odsState.addOperands(operands);
2680+
odsState.addTypes(resultTypes);
2681+
2682+
OpBuilder::InsertionGuard guard(odsBuilder);
2683+
2684+
SmallVector<Location, 4> blockArgLocs;
2685+
for (Value operand : operands) {
2686+
blockArgLocs.push_back(operand.getLoc());
2687+
}
2688+
2689+
Region *beforeRegion = odsState.addRegion();
2690+
Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
2691+
resultTypes, blockArgLocs);
2692+
beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
2693+
2694+
Region *afterRegion = odsState.addRegion();
2695+
Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
2696+
resultTypes, blockArgLocs);
2697+
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
2698+
}
2699+
26722700
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
26732701
assert(index && *index == 0 &&
26742702
"WhileOp is expected to branch only to the first region");

0 commit comments

Comments
 (0)