Skip to content

[mlir][affine] implement promoteIfSingleIteration for AffineForOp #72547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"

#include <optional>

namespace mlir {
Expand All @@ -29,20 +30,6 @@ namespace affine {
class AffineForOp;
class NestedPattern;

/// Returns the trip count of the loop as an affine map with its corresponding
/// operands if the latter is expressible as an affine expression, and nullptr
/// otherwise. This method always succeeds as long as the lower bound is not a
/// multi-result map. The trip count expression is simplified before returning.
/// This method only utilizes map composition to construct lower and upper
/// bounds before computing the trip count expressions
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
SmallVectorImpl<Value> *operands);

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This uses affine expression analysis and is able to determine
/// constant trip count in non-trivial cases.
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
Expand Down
29 changes: 25 additions & 4 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the source memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
}

/// Returns the source memref affine map indices for this DMA operation.
Expand Down Expand Up @@ -156,7 +157,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the destination memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getDstMapAttrStrName()));
}

/// Returns the destination memref indices for this DMA operation.
Expand Down Expand Up @@ -185,7 +187,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getTagMapAttrStrName()));
}

/// Returns the tag memref indices for this DMA operation.
Expand Down Expand Up @@ -307,7 +310,8 @@ class AffineDmaWaitOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getTagMapAttrStrName()));
}

/// Returns the tag memref index for this DMA operation.
Expand Down Expand Up @@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
/// AffineParallelOp.
AffineParallelOp getAffineParallelInductionVarOwner(Value val);

/// Helper to replace uses of loop carried values (iter_args) and loop
/// yield values while promoting single iteration affine.for ops.
void replaceIterArgsAndYieldResults(AffineForOp forOp);

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands);

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This uses affine expression analysis and is able to determine
/// constant trip count in non-trivial cases.
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);

/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound", "getYieldedValuesMutable",
"getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Affine/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
uint64_t unrollJamFactor);

/// Promotes the loop body of a AffineForOp to its containing block if the loop
/// was known to have a single iteration.
LogicalResult promoteIfSingleIteration(AffineForOp forOp);

/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
/// their body into the containing Block.
void promoteSingleIterationLoops(func::FuncOp f);
Expand Down
80 changes: 1 addition & 79 deletions mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,101 +12,23 @@

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Support/MathExtras.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"

#include <numeric>
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::affine;

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void mlir::affine::getTripCountMapAndOperands(
AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands) {
MLIRContext *context = forOp.getContext();
int64_t step = forOp.getStepAsInt();
int64_t loopSpan;
if (forOp.hasConstantBounds()) {
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
loopSpan = ub - lb;
if (loopSpan < 0)
loopSpan = 0;
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
tripCountOperands->clear();
return;
}
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
if (lbMap.getNumResults() != 1) {
*tripCountMap = AffineMap();
return;
}

// Difference of each upper bound expression from the single lower bound
// expression (divided by the step) provides the expressions for the trip
// count map.
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());

SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
lbMap.getResult(0));
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
lbSplatExpr, context);
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());

AffineValueMap tripCountValueMap;
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
tripCountValueMap.setResult(i,
tripCountValueMap.getResult(i).ceilDiv(step));

*tripCountMap = tripCountValueMap.getAffineMap();
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
tripCountValueMap.getOperands().end());
}

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This method uses affine expression analysis (in turn using
/// getTripCount) and is able to determine constant trip count in non-trivial
/// cases.
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
SmallVector<Value, 4> operands;
AffineMap map;
getTripCountMapAndOperands(forOp, &map, &operands);

if (!map)
return std::nullopt;

// Take the min if all trip counts are constant.
std::optional<uint64_t> tripCount;
for (auto resultExpr : map.getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
if (tripCount.has_value())
tripCount =
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
else
tripCount = constExpr.getValue();
} else
return std::nullopt;
}
return tripCount;
}

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
Expand Down
147 changes: 140 additions & 7 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -2440,6 +2442,65 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}

void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
// Replace uses of iter arguments with iter operands (initial values).
OperandRange iterOperands = forOp.getInits();
MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
arg.replaceAllUsesWith(operand);

// Replace uses of loop results with the values yielded by the loop.
ResultRange outerResults = forOp.getResults();
OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
outer.replaceAllUsesWith(inner);
}

LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
auto forOp = cast<AffineForOp>(getOperation());
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount || *tripCount != 1)
return failure();

// TODO: extend this for arbitrary affine bounds.
if (forOp.getLowerBoundMap().getNumResults() != 1)
return failure();

// Replaces all IV uses to its single iteration value.
BlockArgument iv = forOp.getInductionVar();
Block *parentBlock = forOp->getBlock();
if (!iv.use_empty()) {
if (forOp.hasConstantLowerBound()) {
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv.replaceAllUsesWith(constOp);
} else {
OperandRange lbOperands = forOp.getLowerBoundOperands();
AffineMap lbMap = forOp.getLowerBoundMap();
OpBuilder builder(forOp);
if (lbMap == builder.getDimIdentityMap()) {
// No need of generating an affine.apply.
iv.replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp =
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
iv.replaceAllUsesWith(affineApplyOp);
}
}
}

replaceIterArgsAndYieldResults(forOp);

// Move the loop body operations, except for its terminator, to the loop's
// containing block.
forOp.getBody()->back().erase();
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
return success();
}

FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
RewriterBase &rewriter, ValueRange newInitOperands,
bool replaceInitOperandUsesInLoop,
Expand Down Expand Up @@ -2538,6 +2599,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
return nullptr;
}

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void mlir::affine::getTripCountMapAndOperands(
AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands) {
MLIRContext *context = forOp.getContext();
int64_t step = forOp.getStepAsInt();
int64_t loopSpan;
if (forOp.hasConstantBounds()) {
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
loopSpan = ub - lb;
if (loopSpan < 0)
loopSpan = 0;
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
tripCountOperands->clear();
return;
}
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
if (lbMap.getNumResults() != 1) {
*tripCountMap = AffineMap();
return;
}

// Difference of each upper bound expression from the single lower bound
// expression (divided by the step) provides the expressions for the trip
// count map.
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());

SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
lbMap.getResult(0));
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
lbSplatExpr, context);
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());

AffineValueMap tripCountValueMap;
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
tripCountValueMap.setResult(i,
tripCountValueMap.getResult(i).ceilDiv(step));

*tripCountMap = tripCountValueMap.getAffineMap();
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
tripCountValueMap.getOperands().end());
}

std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
SmallVector<Value, 4> operands;
AffineMap map;
getTripCountMapAndOperands(forOp, &map, &operands);

if (!map)
return std::nullopt;

// Take the min if all trip counts are constant.
std::optional<uint64_t> tripCount;
for (auto resultExpr : map.getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
if (tripCount.has_value())
tripCount =
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
else
tripCount = constExpr.getValue();
} else
return std::nullopt;
}
return tripCount;
}

/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
Expand Down Expand Up @@ -2905,8 +3039,7 @@ static void composeSetAndOperands(IntegerSet &set,
}

/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
composeSetAndOperands(set, operands);
Expand Down Expand Up @@ -2997,11 +3130,11 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");

Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
Expand Down
Loading