Skip to content

Commit 54844eb

Browse files
[mlir][Interfaces] LoopLikeOpInterface: Add helper to get yielded values
Add a new interface method that returns the yielded values. Also add a verifier that checks the number of inits/iter_args/yielded values. Most of the checked invariants (but not all of them) are already covered by the `RegionBranchOpInterface`, but the `LoopLikeOpInterface` now provides (additional) error messages that are easier to read.
1 parent 5cacf4e commit 54844eb

File tree

14 files changed

+153
-28
lines changed

14 files changed

+153
-28
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,6 +2112,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
21122112
mlir::Operation::operand_range getIterOperands() {
21132113
return getOperands().drop_front(getNumControlOperands());
21142114
}
2115+
mlir::OperandRange getInits() { return getIterOperands(); }
2116+
mlir::ValueRange getYieldedValues();
21152117

21162118
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
21172119
void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }
@@ -2263,6 +2265,8 @@ def fir_IterWhileOp : region_Op<"iterate_while",
22632265
mlir::Operation::operand_range getIterOperands() {
22642266
return getOperands().drop_front(getNumControlOperands());
22652267
}
2268+
mlir::OperandRange getInits() { return getIterOperands(); }
2269+
mlir::ValueRange getYieldedValues();
22662270

22672271
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
22682272
void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,12 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
19721972
return {};
19731973
}
19741974

1975+
mlir::ValueRange fir::IterWhileOp::getYieldedValues() {
1976+
auto *term = getRegion().front().getTerminator();
1977+
return getFinalValue() ? term->getOperands().drop_front()
1978+
: term->getOperands();
1979+
}
1980+
19751981
//===----------------------------------------------------------------------===//
19761982
// LenParamIndexOp
19771983
//===----------------------------------------------------------------------===//
@@ -2267,6 +2273,12 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
22672273
return {};
22682274
}
22692275

2276+
mlir::ValueRange fir::DoLoopOp::getYieldedValues() {
2277+
auto *term = getRegion().front().getTerminator();
2278+
return getFinalValue() ? term->getOperands().drop_front()
2279+
: term->getOperands();
2280+
}
2281+
22702282
//===----------------------------------------------------------------------===//
22712283
// DTEntryOp
22722284
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def AffineForOp : Affine_Op<"for",
121121
ImplicitAffineTerminator, ConditionallySpeculatable,
122122
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
123123
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
124-
"getSingleUpperBound", "replaceWithAdditionalYields"]>,
124+
"getSingleUpperBound", "getYieldedValues",
125+
"replaceWithAdditionalYields"]>,
125126
DeclareOpInterfaceMethods<RegionBranchOpInterface,
126127
["getEntrySuccessorOperands"]>]> {
127128
let summary = "for operation";

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
122122
def ForOp : SCF_Op<"for",
123123
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
124124
["getInits", "getSingleInductionVar", "getSingleLowerBound",
125-
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
126-
"replaceWithAdditionalYields"]>,
125+
"getSingleStep", "getSingleUpperBound", "getYieldedValues",
126+
"promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
127127
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
128128
ConditionallySpeculatable,
129129
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -243,9 +243,11 @@ def ForOp : SCF_Op<"for",
243243
function_ref<void(OpBuilder &, Location, Value, ValueRange)>;
244244

245245
Value getInductionVar() { return getBody()->getArgument(0); }
246+
246247
Block::BlockArgListType getRegionIterArgs() {
247248
return getBody()->getArguments().drop_front(getNumInductionVars());
248249
}
250+
249251
/// Return the `index`-th region iteration argument.
250252
BlockArgument getRegionIterArg(unsigned index) {
251253
assert(index < getNumRegionIterArgs() &&
@@ -1086,6 +1088,11 @@ def WhileOp : SCF_Op<"while",
10861088

10871089
ConditionOp getConditionOp();
10881090
YieldOp getYieldOp();
1091+
1092+
/// Return the values that are yielded from the "after" region (by the
1093+
/// scf.yield op).
1094+
ValueRange getYieldedValues();
1095+
10891096
Block::BlockArgListType getBeforeArguments();
10901097
Block::BlockArgListType getAfterArguments();
10911098
Block *getBeforeBody() { return &getBefore().front(); }

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ class RewriterBase;
2424
/// arguments in `newBbArgs`.
2525
using NewYieldValuesFn = std::function<SmallVector<Value>(
2626
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
27+
28+
namespace detail {
29+
/// Verify invariants of the LoopLikeOpInterface.
30+
LogicalResult verifyLoopLikeOpInterface(Operation *op);
31+
} // namespace detail
2732
} // namespace mlir
2833

2934
/// Include the generated interface declarations.

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
2020
Contains helper functions to query properties and perform transformations
2121
of a loop. Operations that implement this interface will be considered by
2222
loop-invariant code motion.
23+
24+
Loop-carried variables can be exposed through this interface. There are
25+
3 components to a loop-carried variable.
26+
- The "region iter_arg" is the block argument of the entry block that
27+
represents the loop-carried variable in each iteration.
28+
- The "init value" is an operand of the loop op that serves as the initial
29+
region iter_arg value for the first iteration (if any).
30+
- The "yielded" value is the value that is forwarded from one iteration to
31+
serve as the region iter_arg of the next iteration.
32+
33+
If one of the respective interface methods is implemented, so must the other
34+
two. The interface verifier ensures that the number of types of the region
35+
iter_args, init values and yielded values match.
2336
}];
2437
let cppNamespace = "::mlir";
2538

@@ -141,6 +154,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
141154
return ::mlir::Block::BlockArgListType();
142155
}]
143156
>,
157+
InterfaceMethod<[{
158+
Return the values that are yielded to the next iteration.
159+
}],
160+
/*retTy=*/"::mlir::ValueRange",
161+
/*methodName=*/"getYieldedValues",
162+
/*args=*/(ins),
163+
/*methodBody=*/"",
164+
/*defaultImplementation=*/[{
165+
return ::mlir::ValueRange();
166+
}]
167+
>,
144168
InterfaceMethod<[{
145169
Append the specified additional "init" operands: replace this loop with
146170
a new loop that has the additional init operands. The loop body of
@@ -192,6 +216,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
192216
});
193217
}
194218
}];
219+
220+
let verifyWithRegions = 1;
221+
222+
let verify = [{
223+
return detail::verifyLoopLikeOpInterface($_op);
224+
}];
195225
}
196226

197227
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,10 @@ unsigned AffineForOp::getNumIterOperands() {
22152215
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
22162216
}
22172217

2218+
ValueRange AffineForOp::getYieldedValues() {
2219+
return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
2220+
}
2221+
22182222
void AffineForOp::print(OpAsmPrinter &p) {
22192223
p << ' ';
22202224
p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
811811
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
812812

813813
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
814-
auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
815-
auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
814+
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
816815
.getDefiningOp<tensor::ExtractSliceOp>();
817816
if (!yieldingExtractSliceOp)
818817
return tensor::ExtractSliceOp();
@@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
826825

827826
SmallVector<Value> initArgs = forOp.getInitArgs();
828827
initArgs[iterArgNumber] = hoistedPackedTensor;
829-
SmallVector<Value> yieldOperands = yieldOp.getOperands();
828+
SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
830829
yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
831830

832831
int64_t numOriginalForOpResults = initArgs.size();

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
400400

401401
// Replace all results with the yielded values.
402402
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
403-
rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
403+
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
404404

405405
// Replace block arguments with lower bound (replacement for IV) and
406406
// iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
772772
LogicalResult matchAndRewrite(scf::ForOp forOp,
773773
PatternRewriter &rewriter) const final {
774774
bool canonicalize = false;
775-
Block &block = forOp.getRegion().front();
776-
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
777775

778776
// An internal flat vector of block transfer
779777
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
780778
// transformed block argument mappings. This plays the role of a
781779
// IRMapping for the particular use case of calling into
782780
// `inlineBlockBefore`.
781+
int64_t numResults = forOp.getNumResults();
783782
SmallVector<bool, 4> keepMask;
784-
keepMask.reserve(yieldOp.getNumOperands());
783+
keepMask.reserve(numResults);
785784
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
786785
newResultValues;
787-
newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
786+
newBlockTransferArgs.reserve(1 + numResults);
788787
newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
789788
newIterArgs.reserve(forOp.getInitArgs().size());
790-
newYieldValues.reserve(yieldOp.getNumOperands());
791-
newResultValues.reserve(forOp.getNumResults());
789+
newYieldValues.reserve(numResults);
790+
newResultValues.reserve(numResults);
792791
for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
793792
forOp.getRegionIterArgs(), // iter inside region
794793
forOp.getResults(), // op results
795-
yieldOp.getOperands() // iter yield
794+
forOp.getYieldedValues() // iter yield
796795
)) {
797796
// Forwarded is `true` when:
798797
// 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
946945
return failure();
947946
// If the loop is empty, iterates at least once, and only returns values
948947
// defined outside of the loop, remove it and replace it with yield values.
949-
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
950-
auto yieldOperands = yieldOp.getOperands();
951-
if (llvm::any_of(yieldOperands,
948+
if (llvm::any_of(op.getYieldedValues(),
952949
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
953950
return failure();
954-
rewriter.replaceOp(op, yieldOperands);
951+
rewriter.replaceOp(op, op.getYieldedValues());
955952
return success();
956953
}
957954
};
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
12241221
return {};
12251222
}
12261223

1224+
ValueRange ForOp::getYieldedValues() {
1225+
return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
1226+
}
1227+
12271228
Speculation::Speculatability ForOp::getSpeculatability() {
12281229
// `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
12291230
// and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
32053206
return cast<YieldOp>(getAfterBody()->getTerminator());
32063207
}
32073208

3209+
ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }
3210+
32083211
Block::BlockArgListType WhileOp::getBeforeArguments() {
32093212
return getBeforeBody()->getArguments();
32103213
}

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,8 @@ struct ForOpInterface
605605
auto forOp = cast<scf::ForOp>(op);
606606
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
607607
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
608-
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
609608
bool equivalentYield = state.areEquivalentBufferizedValues(
610-
bbArg, yieldOp->getOperand(opResult.getResultNumber()));
609+
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
611610
return equivalentYield ? BufferRelation::Equivalent
612611
: BufferRelation::Unknown;
613612
}

mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,9 @@ using namespace mlir::scf;
3636
/// type of the corresponding basic block argument of the loop.
3737
/// Note: This function handles only simple cases. Expand as needed.
3838
static bool isShapePreserving(ForOp forOp, int64_t arg) {
39-
auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
40-
assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
39+
assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
4140
"arg is out of bounds");
42-
Value value = yieldOp.getResults()[arg];
41+
Value value = forOp.getYieldedValues()[arg];
4342
while (value) {
4443
if (value == forOp.getRegionIterArgs()[arg])
4544
return true;

mlir/lib/Interfaces/LoopLikeInterface.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,40 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
5252
}
5353
return false;
5454
}
55+
56+
LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
57+
// Note: These invariants are also verified by the RegionBranchOpInterface,
58+
// but the LoopLikeOpInterface provides better error messages.
59+
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
60+
61+
// Verify number of inits/iter_args/yielded values.
62+
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
63+
return op->emitOpError("different number of inits and region iter_args: ")
64+
<< loopLikeOp.getInits().size()
65+
<< " != " << loopLikeOp.getRegionIterArgs().size();
66+
if (loopLikeOp.getRegionIterArgs().size() !=
67+
loopLikeOp.getYieldedValues().size())
68+
return op->emitOpError(
69+
"different number of region iter_args and yielded values: ")
70+
<< loopLikeOp.getRegionIterArgs().size()
71+
<< " != " << loopLikeOp.getYieldedValues().size();
72+
73+
// Verify types of inits/iter_args/yielded values.
74+
int64_t i = 0;
75+
for (const auto it :
76+
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
77+
loopLikeOp.getYieldedValues())) {
78+
if (std::get<0>(it).getType() != std::get<1>(it).getType())
79+
op->emitOpError(std::to_string(i))
80+
<< "-th init and " << i << "-th region iter_arg have different type: "
81+
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
82+
if (std::get<1>(it).getType() != std::get<2>(it).getType())
83+
op->emitOpError(std::to_string(i))
84+
<< "-th region iter_arg and " << i
85+
<< "-th yielded value have different type: "
86+
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
87+
++i;
88+
}
89+
90+
return success();
91+
}

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
9696

9797
// -----
9898

99+
func.func @too_many_iter_args(%arg0: index, %init: f32) {
100+
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
101+
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
102+
{
103+
^bb0(%i0 : index, %iter: f32, %iter2: f32):
104+
scf.yield %iter, %iter : f32, f32
105+
}
106+
) : (index, index, index, f32) -> (f32)
107+
return
108+
}
109+
110+
// -----
111+
112+
func.func @too_few_yielded_values(%arg0: index, %init: f32) {
113+
// expected-error @below{{different number of region iter_args and yielded values: 2 != 1}}
114+
%x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) (
115+
{
116+
^bb0(%i0 : index, %iter: f32, %iter2: f32):
117+
scf.yield %iter : f32
118+
}
119+
) : (index, index, index, f32, f32) -> (f32, f32)
120+
return
121+
}
122+
123+
// -----
124+
99125
func.func @loop_if_not_i1(%arg0: index) {
100126
// expected-error@+1 {{operand #0 must be 1-bit signless integer}}
101127
"scf.if"(%arg0) ({}, {}) : (index) -> ()
@@ -422,7 +448,8 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
422448
func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
423449
%s0 = arith.constant 0.0 : f32
424450
%t0 = arith.constant 1.0 : f32
425-
// expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
451+
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
452+
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
426453
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
427454
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
428455
%sn = arith.addf %si, %si : f32
@@ -432,7 +459,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
432459
return
433460
}
434461

435-
436462
// -----
437463

438464
func.func @parallel_invalid_yield(

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,8 @@ struct TestSCFForUtilsPass
5050
auto newInitValues = forOp.getInitArgs();
5151
if (newInitValues.empty())
5252
return;
53-
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
54-
SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
55-
yieldOp.getResults().end());
53+
SmallVector<Value> oldYieldValues =
54+
llvm::to_vector(forOp.getYieldedValues());
5655
NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
5756
ArrayRef<BlockArgument> newBBArgs) {
5857
SmallVector<Value> newYieldValues;

0 commit comments

Comments
 (0)