Skip to content

Commit 98a6edd

Browse files
[mlir][Interfaces] LoopLikeOpInterface: Expose tied loop results (#70535)
Expose loop results, which correspond to the region iter_arg values that are returned from the loop when there are no more iterations. Exposing loop results is optional because some loops (e.g., `scf.while`) do not have a 1-to-1 mapping between region iter_args and op results. Also add additional helper functions to query tied results/iter_args/inits.
1 parent e599978 commit 98a6edd

File tree

8 files changed

+142
-46
lines changed

8 files changed

+142
-46
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
269269
/// Number of operands controlling the loop: lb, ub, step
270270
unsigned getNumControlOperands() { return 3; }
271271

272-
/// Get the OpResult that corresponds to an OpOperand.
273-
/// Assert that opOperand is an iterArg.
274-
/// This helper prevents internal op implementation detail leakage to
275-
/// clients by hiding the operand / block argument mapping.
276-
OpResult getResultForOpOperand(OpOperand &opOperand) {
277-
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
278-
"expected an iter args operand");
279-
assert(opOperand.getOwner() == getOperation() &&
280-
"opOperand does not belong to this scf::ForOp operation");
281-
return getOperation()->getResult(
282-
opOperand.getOperandNumber() - getNumControlOperands());
283-
}
284-
/// Get the OpOperand& that corresponds to an OpResultOpOperand.
285-
/// This helper prevents internal op implementation detail leakage to
286-
/// clients by hiding the operand / block argument mapping.
287-
OpOperand &getOpOperandForResult(OpResult opResult) {
288-
assert(opResult.getDefiningOp() == getOperation() &&
289-
"opResult does not belong to the scf::ForOp operation");
290-
return getOperation()->getOpOperand(
291-
getNumControlOperands() + opResult.getResultNumber());
292-
}
293-
294272
/// Returns the step as an `APInt` if it is constant.
295273
std::optional<APInt> getConstantStep();
296274

@@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
942920
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
943921
["getEntrySuccessorOperands"]>,
944922
DeclareOpInterfaceMethods<LoopLikeOpInterface,
945-
["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
923+
["getRegionIterArgs", "getYieldedValuesMutable"]>,
946924
RecursiveMemoryEffects, SingleBlock]> {
947925
let summary = "a generic 'while' loop";
948926
let description = [{
@@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
11561134
//===----------------------------------------------------------------------===//
11571135

11581136
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
1159-
ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
1137+
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
11601138
"ParallelOp", "WhileOp"]>]> {
11611139
let summary = "loop yield and termination operation";
11621140
let description = [{

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
3333
If one of the respective interface methods is implemented, so must the other
3434
two. The interface verifier ensures that the number of types of the region
3535
iter_args, init values and yielded values match.
36+
37+
Optionally, "loop results" can be exposed through this interface. These are
38+
the values that are returned from the loop op when there are no more
39+
iterations. The number and types of the loop results must match with the
40+
region iter_args. Note: Loop results are optional because some loops
41+
(e.g., `scf.while`) may produce results that do match 1-to-1 with the
42+
region iter_args.
3643
}];
3744
let cppNamespace = "::mlir";
3845

@@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
166173
return {};
167174
}]
168175
>,
176+
InterfaceMethod<[{
177+
Return the range of results that are return from this loop and
178+
correspond to the "init" operands.
179+
180+
Note: This interface method is optional. If loop results are not
181+
exposed via this interface, "std::nullopt" should be returned.
182+
Otherwise, the number and types of results must match with the
183+
region iter_args, inits and yielded values that are exposed via this
184+
interface. If loop results are exposed but this loop op has no
185+
loop-carried variables, an empty result range (and not "std::nullopt")
186+
should be returned.
187+
}],
188+
/*retTy=*/"::std::optional<::mlir::ResultRange>",
189+
/*methodName=*/"getLoopResults",
190+
/*args=*/(ins),
191+
/*methodBody=*/"",
192+
/*defaultImplementation=*/[{
193+
return ::std::nullopt;
194+
}]
195+
>,
169196
InterfaceMethod<[{
170197
Append the specified additional "init" operands: replace this loop with
171198
a new loop that has the additional init operands. The loop body of
@@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
242269
}
243270

244271
/// Return the region iter_arg that corresponds to the given init operand.
272+
/// Return an "empty" block argument if the given operand is not an init
273+
/// operand of this loop op.
245274
BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
246275
auto initsMutable = $_op.getInitsMutable();
247276
auto it = llvm::find(initsMutable, *opOperand);
@@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
250279
return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
251280
}
252281

282+
/// Return the region iter_arg that corresponds to the given loop result.
283+
/// Return an "empty" block argument if the given OpResult is not a loop
284+
/// result or if this op does not expose any loop results.
285+
BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
286+
auto loopResults = $_op.getLoopResults();
287+
if (!loopResults)
288+
return {};
289+
auto it = llvm::find(*loopResults, opResult);
290+
if (it == loopResults->end())
291+
return {};
292+
return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
293+
}
294+
253295
/// Return the init operand that corresponds to the given region iter_arg.
296+
/// Return "nullptr" if the given block argument is not a region iter_arg
297+
/// of this loop op.
254298
OpOperand *getTiedLoopInit(BlockArgument bbArg) {
255299
auto iterArgs = $_op.getRegionIterArgs();
256300
auto it = llvm::find(iterArgs, bbArg);
@@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
259303
return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
260304
}
261305

306+
/// Return the init operand that corresponds to the given loop result.
307+
/// Return "nullptr" if the given OpResult is not a loop result or if this
308+
/// op does not expose any loop results.
309+
OpOperand *getTiedLoopInit(OpResult opResult) {
310+
auto loopResults = $_op.getLoopResults();
311+
if (!loopResults)
312+
return nullptr;
313+
auto it = llvm::find(*loopResults, opResult);
314+
if (it == loopResults->end())
315+
return nullptr;
316+
return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
317+
}
318+
262319
/// Return the yielded value that corresponds to the given region iter_arg.
320+
/// Return "nullptr" if the given block argument is not a region iter_arg
321+
/// of this loop op.
263322
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
264323
auto iterArgs = $_op.getRegionIterArgs();
265324
auto it = llvm::find(iterArgs, bbArg);
@@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
268327
return
269328
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
270329
}
330+
331+
/// Return the loop result that corresponds to the given init operand.
332+
/// Return an "empty" OpResult if the given operand is not an init operand
333+
/// of this loop op or if this op does not expose any loop results.
334+
OpResult getTiedLoopResult(OpOperand *opOperand) {
335+
auto loopResults = $_op.getLoopResults();
336+
if (!loopResults)
337+
return {};
338+
auto initsMutable = $_op.getInitsMutable();
339+
auto it = llvm::find(initsMutable, *opOperand);
340+
if (it == initsMutable.end())
341+
return {};
342+
return (*loopResults)[std::distance(initsMutable.begin(), it)];
343+
}
344+
345+
/// Return the loop result that corresponds to the given region iter_arg.
346+
/// Return an "empty" OpResult if the given block argument is not a region
347+
/// iter_arg of this loop op or if this op does not expose any loop results.
348+
OpResult getTiedLoopResult(BlockArgument bbArg) {
349+
auto loopResults = $_op.getLoopResults();
350+
if (!loopResults)
351+
return {};
352+
auto iterArgs = $_op.getRegionIterArgs();
353+
auto it = llvm::find(iterArgs, bbArg);
354+
if (it == iterArgs.end())
355+
return {};
356+
return (*loopResults)[std::distance(iterArgs.begin(), it)];
357+
}
271358
}];
272359

273360
let verifyWithRegions = 1;

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
810810
OpBuilder::InsertionGuard g(rewriter);
811811
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
812812

813-
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
813+
unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
814814
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
815815
.getDefiningOp<tensor::ExtractSliceOp>();
816816
if (!yieldingExtractSliceOp)

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
390390
return OpFoldResult(getUpperBound());
391391
}
392392

393+
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
394+
393395
/// Promotes the loop body of a forOp to its containing block if the forOp
394396
/// it can be determined that the loop has a single iteration.
395397
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ struct ForOpInterface
614614
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
615615
const AnalysisState &state) const {
616616
auto forOp = cast<scf::ForOp>(op);
617-
OpResult opResult = forOp.getResultForOpOperand(opOperand);
617+
OpResult opResult = forOp.getTiedLoopResult(&opOperand);
618618
BufferRelation relation = bufferRelation(op, opResult, state);
619619
return {{opResult, relation,
620620
/*isDefinite=*/relation == BufferRelation::Equivalent}};
@@ -625,10 +625,9 @@ struct ForOpInterface
625625
// ForOp results are equivalent to their corresponding init_args if the
626626
// corresponding iter_args and yield values are equivalent.
627627
auto forOp = cast<scf::ForOp>(op);
628-
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
629-
auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
628+
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
630629
bool equivalentYield = state.areEquivalentBufferizedValues(
631-
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
630+
bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
632631
return equivalentYield ? BufferRelation::Equivalent
633632
: BufferRelation::Unknown;
634633
}
@@ -703,16 +702,13 @@ struct ForOpInterface
703702

704703
if (auto opResult = dyn_cast<OpResult>(value)) {
705704
// The type of an OpResult must match the corresponding iter_arg type.
706-
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
707-
&forOp.getOpOperandForResult(opResult));
705+
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
708706
return bufferization::getBufferType(bbArg, options, invocationStack);
709707
}
710708

711709
// Compute result/argument number.
712710
BlockArgument bbArg = cast<BlockArgument>(value);
713-
unsigned resultNum =
714-
forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
715-
.getResultNumber();
711+
unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
716712

717713
// Compute the bufferized type.
718714
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
609609
if (destinationInitArg &&
610610
(*destinationInitArg)->getOwner() == outerMostLoop) {
611611
unsigned iterArgNumber =
612-
outerMostLoop.getResultForOpOperand(**destinationInitArg)
613-
.getResultNumber();
612+
outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
614613
int64_t resultNumber = fusableProducer.getResultNumber();
615614
if (auto dstOp =
616615
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {

mlir/lib/Interfaces/LoopLikeInterface.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
5858
// but the LoopLikeOpInterface provides better error messages.
5959
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
6060

61-
// Verify number of inits/iter_args/yielded values.
61+
// Verify number of inits/iter_args/yielded values/loop results.
6262
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
6363
return op->emitOpError("different number of inits and region iter_args: ")
6464
<< loopLikeOp.getInits().size()
@@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
6969
"different number of region iter_args and yielded values: ")
7070
<< loopLikeOp.getRegionIterArgs().size()
7171
<< " != " << loopLikeOp.getYieldedValues().size();
72+
if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
73+
loopLikeOp.getRegionIterArgs().size())
74+
return op->emitOpError(
75+
"different number of loop results and region iter_args: ")
76+
<< loopLikeOp.getLoopResults()->size()
77+
<< " != " << loopLikeOp.getRegionIterArgs().size();
7278

73-
// Verify types of inits/iter_args/yielded values.
79+
// Verify types of inits/iter_args/yielded values/loop results.
7480
int64_t i = 0;
7581
for (const auto it :
7682
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
7783
loopLikeOp.getYieldedValues())) {
7884
if (std::get<0>(it).getType() != std::get<1>(it).getType())
79-
op->emitOpError(std::to_string(i))
80-
<< "-th init and " << i << "-th region iter_arg have different type: "
81-
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
85+
return op->emitOpError(std::to_string(i))
86+
<< "-th init and " << i
87+
<< "-th region iter_arg have different type: "
88+
<< std::get<0>(it).getType()
89+
<< " != " << std::get<1>(it).getType();
8290
if (std::get<1>(it).getType() != std::get<2>(it).getType())
83-
op->emitOpError(std::to_string(i))
84-
<< "-th region iter_arg and " << i
85-
<< "-th yielded value have different type: "
86-
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
91+
return op->emitOpError(std::to_string(i))
92+
<< "-th region iter_arg and " << i
93+
<< "-th yielded value have different type: "
94+
<< std::get<1>(it).getType()
95+
<< " != " << std::get<2>(it).getType();
96+
++i;
97+
}
98+
i = 0;
99+
if (loopLikeOp.getLoopResults()) {
100+
for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
101+
*loopLikeOp.getLoopResults())) {
102+
if (std::get<0>(it).getType() != std::get<1>(it).getType())
103+
return op->emitOpError(std::to_string(i))
104+
<< "-th region iter_arg and " << i
105+
<< "-th loop result have different type: "
106+
<< std::get<0>(it).getType()
107+
<< " != " << std::get<1>(it).getType();
108+
}
87109
++i;
88110
}
89111

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
9696

9797
// -----
9898

99+
func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
100+
// expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
101+
"scf.for"(%arg0, %arg0, %arg0, %init) (
102+
{
103+
^bb0(%i0 : index, %iter: f32):
104+
scf.yield %iter : f32
105+
}
106+
) : (index, index, index, f32) -> (f64)
107+
return
108+
}
109+
110+
// -----
111+
99112
func.func @too_many_iter_args(%arg0: index, %init: f32) {
100113
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
101114
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
@@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
449462
%s0 = arith.constant 0.0 : f32
450463
%t0 = arith.constant 1.0 : f32
451464
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
452-
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
453465
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
454466
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
455467
%sn = arith.addf %si, %si : f32

0 commit comments

Comments
 (0)