Skip to content

Commit 833a8db

Browse files
authored
[mlir][scf] Implement getSingle... of LoopLikeOpInterface for scf::ForallOp (#67883)
The `getSingle(IterationVar|UpperBound|LowerBound|Step)` methods of `LoopLikeOpInterface` are useful to quickly query the iteration space of unidimensional loops. Until now, `scf::ForallOp` always fell back to the default implementation of these methods, returning `std::nullopt`. This patch implements those methods, returning the respective bounds or steps in the special case of `rank == 1`.
1 parent 7f7a15c commit 833a8db

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,8 @@ def ForallOp : SCF_Op<"forall", [
333333
AttrSizedOperandSegments,
334334
AutomaticAllocationScope,
335335
DeclareOpInterfaceMethods<LoopLikeOpInterface,
336-
["promoteIfSingleIteration"]>,
336+
["promoteIfSingleIteration", "getSingleInductionVar",
337+
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
337338
RecursiveMemoryEffects,
338339
SingleBlockImplicitTerminator<"scf::InParallelOp">,
339340
DeclareOpInterfaceMethods<RegionBranchOpInterface>,

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,6 +1527,30 @@ InParallelOp ForallOp::getTerminator() {
15271527
return cast<InParallelOp>(getBody()->getTerminator());
15281528
}
15291529

1530+
std::optional<Value> ForallOp::getSingleInductionVar() {
1531+
if (getRank() != 1)
1532+
return std::nullopt;
1533+
return getInductionVar(0);
1534+
}
1535+
1536+
std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1537+
if (getRank() != 1)
1538+
return std::nullopt;
1539+
return getMixedLowerBound()[0];
1540+
}
1541+
1542+
std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1543+
if (getRank() != 1)
1544+
return std::nullopt;
1545+
return getMixedUpperBound()[0];
1546+
}
1547+
1548+
std::optional<OpFoldResult> ForallOp::getSingleStep() {
1549+
if (getRank() != 1)
1550+
return std::nullopt;
1551+
return getMixedStep()[0];
1552+
}
1553+
15301554
ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
15311555
auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
15321556
if (!tidxArg)

0 commit comments

Comments
 (0)