Skip to content

[mlir][Interfaces] LoopLikeOpInterface: Expose tied loop results #70535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }

/// Get the OpResult that corresponds to an OpOperand.
/// Assert that opOperand is an iterArg.
/// This helper prevents internal op implementation detail leakage to
/// clients by hiding the operand / block argument mapping.
OpResult getResultForOpOperand(OpOperand &opOperand) {
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
"expected an iter args operand");
assert(opOperand.getOwner() == getOperation() &&
"opOperand does not belong to this scf::ForOp operation");
return getOperation()->getResult(
opOperand.getOperandNumber() - getNumControlOperands());
}
/// Get the OpOperand& that corresponds to an OpResultOpOperand.
/// This helper prevents internal op implementation detail leakage to
/// clients by hiding the operand / block argument mapping.
OpOperand &getOpOperandForResult(OpResult opResult) {
assert(opResult.getDefiningOp() == getOperation() &&
"opResult does not belong to the scf::ForOp operation");
return getOperation()->getOpOperand(
getNumControlOperands() + opResult.getResultNumber());
}

/// Returns the step as an `APInt` if it is constant.
std::optional<APInt> getConstantStep();

Expand Down Expand Up @@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
Expand Down Expand Up @@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
//===----------------------------------------------------------------------===//

def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
"ParallelOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
Expand Down
87 changes: 87 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
If one of the respective interface methods is implemented, so must the other
two. The interface verifier ensures that the number of types of the region
iter_args, init values and yielded values match.

Optionally, "loop results" can be exposed through this interface. These are
the values that are returned from the loop op when there are no more
iterations. The number and types of the loop results must match with the
region iter_args. Note: Loop results are optional because some loops
(e.g., `scf.while`) may produce results that do match 1-to-1 with the
region iter_args.
}];
let cppNamespace = "::mlir";

Expand Down Expand Up @@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return {};
}]
>,
InterfaceMethod<[{
Return the range of results that are return from this loop and
correspond to the "init" operands.

Note: This interface method is optional. If loop results are not
exposed via this interface, "std::nullopt" should be returned.
Otherwise, the number and types of results must match with the
region iter_args, inits and yielded values that are exposed via this
interface. If loop results are exposed but this loop op has no
loop-carried variables, an empty result range (and not "std::nullopt")
should be returned.
}],
/*retTy=*/"::std::optional<::mlir::ResultRange>",
/*methodName=*/"getLoopResults",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::std::nullopt;
}]
>,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
Expand Down Expand Up @@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}

/// Return the region iter_arg that corresponds to the given init operand.
/// Return an "empty" block argument if the given operand is not an init
/// operand of this loop op.
BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
auto initsMutable = $_op.getInitsMutable();
auto it = llvm::find(initsMutable, *opOperand);
Expand All @@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
}

/// Return the region iter_arg that corresponds to the given loop result.
/// Return an "empty" block argument if the given OpResult is not a loop
/// result or if this op does not expose any loop results.
BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
auto loopResults = $_op.getLoopResults();
if (!loopResults)
return {};
auto it = llvm::find(*loopResults, opResult);
if (it == loopResults->end())
return {};
return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
}

/// Return the init operand that corresponds to the given region iter_arg.
/// Return "nullptr" if the given block argument is not a region iter_arg
/// of this loop op.
OpOperand *getTiedLoopInit(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
Expand All @@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
}

/// Return the init operand that corresponds to the given loop result.
/// Return "nullptr" if the given OpResult is not a loop result or if this
/// op does not expose any loop results.
OpOperand *getTiedLoopInit(OpResult opResult) {
auto loopResults = $_op.getLoopResults();
if (!loopResults)
return nullptr;
auto it = llvm::find(*loopResults, opResult);
if (it == loopResults->end())
return nullptr;
return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
}

/// Return the yielded value that corresponds to the given region iter_arg.
/// Return "nullptr" if the given block argument is not a region iter_arg
/// of this loop op.
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
Expand All @@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
}

/// Return the loop result that corresponds to the given init operand.
/// Return an "empty" OpResult if the given operand is not an init operand
/// of this loop op or if this op does not expose any loop results.
OpResult getTiedLoopResult(OpOperand *opOperand) {
auto loopResults = $_op.getLoopResults();
if (!loopResults)
return {};
auto initsMutable = $_op.getInitsMutable();
auto it = llvm::find(initsMutable, *opOperand);
if (it == initsMutable.end())
return {};
return (*loopResults)[std::distance(initsMutable.begin(), it)];
}

/// Return the loop result that corresponds to the given region iter_arg.
/// Return an "empty" OpResult if the given block argument is not a region
/// iter_arg of this loop op or if this op does not expose any loop results.
OpResult getTiedLoopResult(BlockArgument bbArg) {
auto loopResults = $_op.getLoopResults();
if (!loopResults)
return {};
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
if (it == iterArgs.end())
return {};
return (*loopResults)[std::distance(iterArgs.begin(), it)];
}
}];

let verifyWithRegions = 1;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
return OpFoldResult(getUpperBound());
}

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ struct ForOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto forOp = cast<scf::ForOp>(op);
OpResult opResult = forOp.getResultForOpOperand(opOperand);
OpResult opResult = forOp.getTiedLoopResult(&opOperand);
BufferRelation relation = bufferRelation(op, opResult, state);
return {{opResult, relation,
/*isDefinite=*/relation == BufferRelation::Equivalent}};
Expand All @@ -625,10 +625,9 @@ struct ForOpInterface
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
return equivalentYield ? BufferRelation::Equivalent
: BufferRelation::Unknown;
}
Expand Down Expand Up @@ -703,16 +702,13 @@ struct ForOpInterface

if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
&forOp.getOpOperandForResult(opResult));
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
return bufferization::getBufferType(bbArg, options, invocationStack);
}

// Compute result/argument number.
BlockArgument bbArg = cast<BlockArgument>(value);
unsigned resultNum =
forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
.getResultNumber();
unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

// Compute the bufferized type.
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
unsigned iterArgNumber =
outerMostLoop.getResultForOpOperand(**destinationInitArg)
.getResultNumber();
outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
Expand Down
40 changes: 31 additions & 9 deletions mlir/lib/Interfaces/LoopLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
// but the LoopLikeOpInterface provides better error messages.
auto loopLikeOp = cast<LoopLikeOpInterface>(op);

// Verify number of inits/iter_args/yielded values.
// Verify number of inits/iter_args/yielded values/loop results.
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
return op->emitOpError("different number of inits and region iter_args: ")
<< loopLikeOp.getInits().size()
Expand All @@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
"different number of region iter_args and yielded values: ")
<< loopLikeOp.getRegionIterArgs().size()
<< " != " << loopLikeOp.getYieldedValues().size();
if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
loopLikeOp.getRegionIterArgs().size())
return op->emitOpError(
"different number of loop results and region iter_args: ")
<< loopLikeOp.getLoopResults()->size()
<< " != " << loopLikeOp.getRegionIterArgs().size();

// Verify types of inits/iter_args/yielded values.
// Verify types of inits/iter_args/yielded values/loop results.
int64_t i = 0;
for (const auto it :
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
loopLikeOp.getYieldedValues())) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th init and " << i << "-th region iter_arg have different type: "
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
return op->emitOpError(std::to_string(i))
<< "-th init and " << i
<< "-th region iter_arg have different type: "
<< std::get<0>(it).getType()
<< " != " << std::get<1>(it).getType();
if (std::get<1>(it).getType() != std::get<2>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th yielded value have different type: "
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
return op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th yielded value have different type: "
<< std::get<1>(it).getType()
<< " != " << std::get<2>(it).getType();
++i;
}
i = 0;
if (loopLikeOp.getLoopResults()) {
for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
*loopLikeOp.getLoopResults())) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
return op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th loop result have different type: "
<< std::get<0>(it).getType()
<< " != " << std::get<1>(it).getType();
}
++i;
}

Expand Down
14 changes: 13 additions & 1 deletion mlir/test/Dialect/SCF/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {

// -----

func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
// expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
"scf.for"(%arg0, %arg0, %arg0, %init) (
{
^bb0(%i0 : index, %iter: f32):
scf.yield %iter : f32
}
) : (index, index, index, f32) -> (f64)
return
}

// -----

func.func @too_many_iter_args(%arg0: index, %init: f32) {
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
Expand Down Expand Up @@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32
Expand Down