Skip to content

[mlir][Interfaces] LoopLikeOpInterface: Add helper to get yielded values #67305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
mlir::Operation::operand_range getIterOperands() {
return getOperands().drop_front(getNumControlOperands());
}
mlir::OperandRange getInits() { return getIterOperands(); }
mlir::ValueRange getYieldedValues();

void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }
Expand Down Expand Up @@ -2263,6 +2265,8 @@ def fir_IterWhileOp : region_Op<"iterate_while",
mlir::Operation::operand_range getIterOperands() {
return getOperands().drop_front(getNumControlOperands());
}
mlir::OperandRange getInits() { return getIterOperands(); }
mlir::ValueRange getYieldedValues();

void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,12 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
return {};
}

mlir::ValueRange fir::IterWhileOp::getYieldedValues() {
auto *term = getRegion().front().getTerminator();
return getFinalValue() ? term->getOperands().drop_front()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just double checking, the drop_front here drops the final value, is that correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's correct. (There is similar code in the op verifier: auto opResults = getFinalValue() ? getResults().drop_front() : getResults();)

: term->getOperands();
}

//===----------------------------------------------------------------------===//
// LenParamIndexOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2267,6 +2273,12 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
return {};
}

mlir::ValueRange fir::DoLoopOp::getYieldedValues() {
auto *term = getRegion().front().getTerminator();
return getFinalValue() ? term->getOperands().drop_front()
: term->getOperands();
}

//===----------------------------------------------------------------------===//
// DTEntryOp
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def AffineForOp : Affine_Op<"for",
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound", "replaceWithAdditionalYields"]>,
"getSingleUpperBound", "getYieldedValues",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
let summary = "for operation";
Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInits", "getSingleInductionVar", "getSingleLowerBound",
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
"replaceWithAdditionalYields"]>,
"getSingleStep", "getSingleUpperBound", "getYieldedValues",
"promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
Expand Down Expand Up @@ -243,9 +243,11 @@ def ForOp : SCF_Op<"for",
function_ref<void(OpBuilder &, Location, Value, ValueRange)>;

Value getInductionVar() { return getBody()->getArgument(0); }

Block::BlockArgListType getRegionIterArgs() {
return getBody()->getArguments().drop_front(getNumInductionVars());
}

/// Return the `index`-th region iteration argument.
BlockArgument getRegionIterArg(unsigned index) {
assert(index < getNumRegionIterArgs() &&
Expand Down Expand Up @@ -1086,6 +1088,11 @@ def WhileOp : SCF_Op<"while",

ConditionOp getConditionOp();
YieldOp getYieldOp();

/// Return the values that are yielded from the "after" region (by the
/// scf.yield op).
ValueRange getYieldedValues();

Block::BlockArgListType getBeforeArguments();
Block::BlockArgListType getAfterArguments();
Block *getBeforeBody() { return &getBefore().front(); }
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class RewriterBase;
/// arguments in `newBbArgs`.
using NewYieldValuesFn = std::function<SmallVector<Value>(
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;

namespace detail {
/// Verify invariants of the LoopLikeOpInterface.
LogicalResult verifyLoopLikeOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Contains helper functions to query properties and perform transformations
of a loop. Operations that implement this interface will be considered by
loop-invariant code motion.

Loop-carried variables can be exposed through this interface. There are
3 components to a loop-carried variable.
- The "region iter_arg" is the block argument of the entry block that
represents the loop-carried variable in each iteration.
- The "init value" is an operand of the loop op that serves as the initial
region iter_arg value for the first iteration (if any).
- The "yielded" value is the value that is forwarded from one iteration to
serve as the region iter_arg of the next iteration.

If one of the respective interface methods is implemented, so must the other
two. The interface verifier ensures that the number of types of the region
iter_args, init values and yielded values match.
}];
let cppNamespace = "::mlir";

Expand Down Expand Up @@ -141,6 +154,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return ::mlir::Block::BlockArgListType();
}]
>,
InterfaceMethod<[{
Return the values that are yielded to the next iteration.
}],
/*retTy=*/"::mlir::ValueRange",
/*methodName=*/"getYieldedValues",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::ValueRange();
}]
>,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
Expand Down Expand Up @@ -192,6 +216,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
});
}
}];

let verifyWithRegions = 1;

let verify = [{
return detail::verifyLoopLikeOpInterface($_op);
}];
}

#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,10 @@ unsigned AffineForOp::getNumIterOperands() {
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
}

ValueRange AffineForOp::getYieldedValues() {
return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
}

void AffineForOp::print(OpAsmPrinter &p) {
p << ' ';
p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
return tensor::ExtractSliceOp();
Expand All @@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,

SmallVector<Value> initArgs = forOp.getInitArgs();
initArgs[iterArgNumber] = hoistedPackedTensor;
SmallVector<Value> yieldOperands = yieldOp.getOperands();
SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

int64_t numOriginalForOpResults = initArgs.size();
Expand Down
27 changes: 15 additions & 12 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {

// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());

// Replace block arguments with lower bound (replacement for IV) and
// iter_args.
Expand Down Expand Up @@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
bool canonicalize = false;
Block &block = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());

// An internal flat vector of block transfer
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
// transformed block argument mappings. This plays the role of a
// IRMapping for the particular use case of calling into
// `inlineBlockBefore`.
int64_t numResults = forOp.getNumResults();
SmallVector<bool, 4> keepMask;
keepMask.reserve(yieldOp.getNumOperands());
keepMask.reserve(numResults);
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
newResultValues;
newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
newBlockTransferArgs.reserve(1 + numResults);
newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(yieldOp.getNumOperands());
newResultValues.reserve(forOp.getNumResults());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
forOp.getRegionIterArgs(), // iter inside region
forOp.getResults(), // op results
yieldOp.getOperands() // iter yield
forOp.getYieldedValues() // iter yield
)) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
Expand Down Expand Up @@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return failure();
// If the loop is empty, iterates at least once, and only returns values
// defined outside of the loop, remove it and replace it with yield values.
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
auto yieldOperands = yieldOp.getOperands();
if (llvm::any_of(yieldOperands,
if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
rewriter.replaceOp(op, yieldOperands);
rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
};
Expand Down Expand Up @@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
return {};
}

ValueRange ForOp::getYieldedValues() {
return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
}

Speculation::Speculatability ForOp::getSpeculatability() {
// `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
// and End.
Expand Down Expand Up @@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
return cast<YieldOp>(getAfterBody()->getTerminator());
}

ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }

Block::BlockArgListType WhileOp::getBeforeArguments() {
return getBeforeBody()->getArguments();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,8 @@ struct ForOpInterface
auto forOp = cast<scf::ForOp>(op);
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, yieldOp->getOperand(opResult.getResultNumber()));
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
return equivalentYield ? BufferRelation::Equivalent
: BufferRelation::Unknown;
}
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ using namespace mlir::scf;
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
"arg is out of bounds");
Value value = yieldOp.getResults()[arg];
Value value = forOp.getYieldedValues()[arg];
while (value) {
if (value == forOp.getRegionIterArgs()[arg])
return true;
Expand Down
37 changes: 37 additions & 0 deletions mlir/lib/Interfaces/LoopLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,40 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
}
return false;
}

LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
// Note: These invariants are also verified by the RegionBranchOpInterface,
// but the LoopLikeOpInterface provides better error messages.
auto loopLikeOp = cast<LoopLikeOpInterface>(op);

// Verify number of inits/iter_args/yielded values.
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
return op->emitOpError("different number of inits and region iter_args: ")
<< loopLikeOp.getInits().size()
<< " != " << loopLikeOp.getRegionIterArgs().size();
if (loopLikeOp.getRegionIterArgs().size() !=
loopLikeOp.getYieldedValues().size())
return op->emitOpError(
"different number of region iter_args and yielded values: ")
<< loopLikeOp.getRegionIterArgs().size()
<< " != " << loopLikeOp.getYieldedValues().size();

// Verify types of inits/iter_args/yielded values.
int64_t i = 0;
for (const auto it :
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
loopLikeOp.getYieldedValues())) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th init and " << i << "-th region iter_arg have different type: "
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
if (std::get<1>(it).getType() != std::get<2>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th yielded value have different type: "
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
++i;
}

return success();
}
30 changes: 28 additions & 2 deletions mlir/test/Dialect/SCF/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {

// -----

func.func @too_many_iter_args(%arg0: index, %init: f32) {
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
{
^bb0(%i0 : index, %iter: f32, %iter2: f32):
scf.yield %iter, %iter : f32, f32
}
) : (index, index, index, f32) -> (f32)
return
}

// -----

func.func @too_few_yielded_values(%arg0: index, %init: f32) {
// expected-error @below{{different number of region iter_args and yielded values: 2 != 1}}
%x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) (
{
^bb0(%i0 : index, %iter: f32, %iter2: f32):
scf.yield %iter : f32
}
) : (index, index, index, f32, f32) -> (f32, f32)
return
}

// -----

func.func @loop_if_not_i1(%arg0: index) {
// expected-error@+1 {{operand #0 must be 1-bit signless integer}}
"scf.if"(%arg0) ({}, {}) : (index) -> ()
Expand Down Expand Up @@ -422,7 +448,8 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
// expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32
Expand All @@ -432,7 +459,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
return
}


// -----

func.func @parallel_invalid_yield(
Expand Down
5 changes: 2 additions & 3 deletions mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ struct TestSCFForUtilsPass
auto newInitValues = forOp.getInitArgs();
if (newInitValues.empty())
return;
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
yieldOp.getResults().end());
SmallVector<Value> oldYieldValues =
llvm::to_vector(forOp.getYieldedValues());
NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
SmallVector<Value> newYieldValues;
Expand Down