Skip to content

Generalize affine fusion to work at all depths and inside other region-holding ops #72288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
setValue(i, values[i - start]);
}

/// Looks up the position of the variable with the specified Value. Returns
/// true if found (false otherwise). `pos` is set to the (column) position of
/// the variable.
bool findVar(Value val, unsigned *pos) const;
/// Looks up the position of the variable with the specified Value starting
/// with variables at offset `offset`. Returns true if found (false
/// otherwise). `pos` is set to the (column) position of the variable.
bool findVar(Value val, unsigned *pos, unsigned offset = 0) const;

/// Returns true if a variable with the specified Value exists, false
/// otherwise.
Expand Down
16 changes: 8 additions & 8 deletions mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,20 +213,20 @@ struct MemRefDependenceGraph {
};

/// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered
/// from the outermost 'affine.for' operation to the innermost one.
/// from the outermost 'affine.for' operation to the innermost one while not
/// traversing outside of the surrounding affine scope.
void getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);

/// Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel
/// ops ordered from the outermost one to the innermost.
/// ops ordered from the outermost one to the innermost while not traversing
/// outside of the surrounding affine scope.
void getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs);

/// Populates 'ops' with affine operations enclosing `op` ordered from outermost
/// to innermost. affine.for, affine.if, or affine.parallel ops comprise such
/// surrounding affine ops.
/// TODO: Change this to return a list of enclosing ops up until the op that
/// starts an `AffineScope`. In such a case, `ops` is guaranteed by design to
/// have a successive chain of affine parent ops, and this is primarily what is
/// needed for most analyses.
/// to innermost while stopping at the boundary of the affine scope. affine.for,
/// affine.if, or affine.parallel ops comprise such surrounding affine ops.
/// `ops` is guaranteed by design to have a successive chain of affine parent
/// ops.
void getEnclosingAffineOps(Operation &op, SmallVectorImpl<Operation *> *ops);

/// Returns the nesting depth of this operation, i.e., the number of loops
Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class FusionStrategy {
/// returns a FusionResult explaining why fusion is not feasible.
/// NOTE: This function is not feature complete and should only be used in
/// testing.
/// TODO: Update comments when this function is fully implemented.
FusionResult
canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
ComputationSliceState *srcSlice,
Expand Down
39 changes: 20 additions & 19 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,15 +958,15 @@ areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
/// so that they have the union of all variables, with A's original
/// variables appearing first followed by any of B's variables that didn't
/// appear in A. Local variables in B that have the same division
/// representation as local variables in A are merged into one.
/// representation as local variables in A are merged into one. We allow A
/// and B to have non-unique values for their variables; in such cases, they are
/// still aligned from left to right, with each duplicate in one system aligned
/// with the next not-yet-aligned occurrence of the same value in the other.
// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
FlatLinearValueConstraints *b) {
assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
// A merge/align isn't meaningful if a cst's vars aren't distinct.
assert(areVarsUnique(*a) && "A's values aren't unique");
assert(areVarsUnique(*b) && "B's values aren't unique");

assert(llvm::all_of(
llvm::drop_begin(a->getMaybeValues(), offset),
Expand All @@ -982,9 +982,12 @@ static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
{
// Merge dims from A into B.
unsigned d = offset;
for (auto aDimValue : aDimValues) {
for (Value aDimValue : aDimValues) {
unsigned loc;
if (b->findVar(aDimValue, &loc)) {
// Find from the position `d` since we'd like to also consider the
// possibility of multiple variables with the same `Value`. We align with
// the next appearing one.
if (b->findVar(aDimValue, &loc, d)) {
assert(loc >= offset && "A's dim appears in B's aligned range");
assert(loc < b->getNumDimVars() &&
"A's dim appears in B's non-dim position");
Expand Down Expand Up @@ -1017,15 +1020,12 @@ void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
}

/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be unequal to all
/// other symbols.
/// of symbols. Existing symbols need not be unique; they will be aligned from
/// left to right with duplicates aligned in the same order. Symbols with Value
/// as `None` are considered to be unequal to all other symbols.
void FlatLinearValueConstraints::mergeSymbolVars(
FlatLinearValueConstraints &other) {

assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");

SmallVector<Value, 4> aSymValues;
getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);

Expand All @@ -1034,8 +1034,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
for (Value aSymValue : aSymValues) {
unsigned loc;
// If the var is a symbol in `other`, then align it, otherwise assume that
// it is a new symbol
if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
// it is a new symbol. Search in `other` starting at position `s` since the
// left of it is aligned.
if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
loc < other.getNumDimAndSymbolVars())
other.swapVar(s, loc);
else
Expand All @@ -1051,8 +1052,6 @@ void FlatLinearValueConstraints::mergeSymbolVars(

assert(getNumSymbolVars() == other.getNumSymbolVars() &&
"expected same number of symbols");
assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
}

bool FlatLinearValueConstraints::hasConsistentState() const {
Expand Down Expand Up @@ -1104,9 +1103,11 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
return alignedMap;
}

bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos) const {
unsigned i = 0;
for (const auto &mayBeVar : values) {
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
unsigned offset) const {
unsigned i = offset;
for (const auto &mayBeVar :
ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
if (mayBeVar && *mayBeVar == val) {
*pos = i;
return true;
Expand Down
17 changes: 11 additions & 6 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,15 @@ bool MemRefDependenceGraph::init() {
continue;
SmallVector<AffineForOp, 4> loops;
getAffineForIVs(*user, &loops);
if (loops.empty())
// Find the surrounding affine.for nested immediately within the
// block.
auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
return loop->getBlock() == &block;
});
if (it == loops.end())
continue;
assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
unsigned userLoopNestId = forToNodeMap[loops[0]];
assert(forToNodeMap.count(*it) > 0 && "missing mapping");
unsigned userLoopNestId = forToNodeMap[*it];
addEdge(node.id, userLoopNestId, value);
}
}
Expand Down Expand Up @@ -631,8 +636,8 @@ void mlir::affine::getAffineForIVs(Operation &op,
AffineForOp currAffineForOp;
// Traverse up the hierarchy collecting all 'affine.for' operations while
// skipping over 'affine.if' operations.
while (currOp) {
if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
loops->push_back(currAffineForOp);
currOp = currOp->getParentOp();
}
Expand All @@ -646,7 +651,7 @@ void mlir::affine::getEnclosingAffineOps(Operation &op,

// Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
// affine.parallel operations.
while (currOp) {
while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
ops->push_back(currOp);
currOp = currOp->getParentOp();
Expand Down
65 changes: 52 additions & 13 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,16 @@ struct GreedyFusion {
if (fusedLoopInsPoint == nullptr)
continue;

// It's possible this fusion is at an inner depth (i.e., there are
// common surrounding affine loops for the source and destination for
// ops). We need to get this number because the call to canFuseLoops
// needs to be passed the absolute depth. The max legal depth and the
// depths we try below are however *relative* and as such don't include
// the common depth.
SmallVector<AffineForOp, 4> surroundingLoops;
getAffineForIVs(*dstAffineForOp, &surroundingLoops);
unsigned numSurroundingLoops = surroundingLoops.size();

// Compute the innermost common loop depth for dstNode
// producer-consumer loads/stores.
SmallVector<Operation *, 2> dstMemrefOps;
Expand All @@ -907,7 +917,8 @@ struct GreedyFusion {
if (producerConsumerMemrefs.count(
cast<AffineWriteOpInterface>(op).getMemRef()))
dstMemrefOps.push_back(op);
unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
unsigned dstLoopDepthTest =
getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;

// Check the feasibility of fusing src loop nest into dst loop nest
// at loop depths in range [1, dstLoopDepthTest].
Expand All @@ -916,9 +927,10 @@ struct GreedyFusion {
depthSliceUnions.resize(dstLoopDepthTest);
FusionStrategy strategy(FusionStrategy::ProducerConsumer);
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
FusionResult result = affine::canFuseLoops(
srcAffineForOp, dstAffineForOp,
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
FusionResult result =
affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
/*dstLoopDepth=*/i + numSurroundingLoops,
&depthSliceUnions[i - 1], strategy);

if (result.value == FusionResult::Success)
maxLegalFusionDepth = i;
Expand Down Expand Up @@ -1125,9 +1137,18 @@ struct GreedyFusion {
SmallVector<Operation *, 2> dstLoadOpInsts;
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

// It's possible this fusion is at an inner depth (i.e., there are common
// surrounding affine loops for the source and destination for ops). We
// need to get this number because the call to canFuseLoops needs to be
// passed the absolute depth. The max legal depth and the depths we try
// below are however *relative* and as such don't include the common
// depth.
SmallVector<AffineForOp, 4> surroundingLoops;
getAffineForIVs(*dstAffineForOp, &surroundingLoops);
unsigned numSurroundingLoops = surroundingLoops.size();
SmallVector<AffineForOp, 4> dstLoopIVs;
getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
unsigned dstLoopDepthTest = dstLoopIVs.size();
unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

// Compute loop depth and slice union for fusion.
Expand All @@ -1136,14 +1157,18 @@ struct GreedyFusion {
unsigned maxLegalFusionDepth = 0;
FusionStrategy strategy(memref);
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
FusionResult result = affine::canFuseLoops(
sibAffineForOp, dstAffineForOp,
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
FusionResult result =
affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
/*dstLoopDepth=*/i + numSurroundingLoops,
&depthSliceUnions[i - 1], strategy);

if (result.value == FusionResult::Success)
maxLegalFusionDepth = i;
}

LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
<< maxLegalFusionDepth << '\n');

// Skip if fusion is not feasible at any loop depths.
if (maxLegalFusionDepth == 0)
continue;
Expand Down Expand Up @@ -1238,9 +1263,15 @@ struct GreedyFusion {
SmallVector<AffineForOp, 4> loops;
getAffineForIVs(*user, &loops);
// Skip 'use' if it is not within a loop nest.
if (loops.empty())
// Find the surrounding affine.for nested immediately within the
// block.
auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
return loop->getBlock() == &mdg->block;
});
// Skip 'use' if it is not within a loop nest in `block`.
if (it == loops.end())
continue;
Node *sibNode = mdg->getForOpNode(loops[0]);
Node *sibNode = mdg->getForOpNode(*it);
assert(sibNode != nullptr);
// Skip 'use' if it is not a sibling to 'dstNode'.
if (sibNode->id == dstNode->id)
Expand Down Expand Up @@ -1373,9 +1404,17 @@ void LoopFusion::runOnBlock(Block *block) {
}

void LoopFusion::runOnOperation() {
for (Region &region : getOperation()->getRegions())
for (Block &block : region.getBlocks())
runOnBlock(&block);
// Run fusion on every block that directly contains at least two affine.for
// ops, visiting the enclosing ops in post order.
getOperation()->walk([&](Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
auto affineFors = block.getOps<AffineForOp>();
if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
runOnBlock(&block);
}
}
});
}

std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
return loopDepth;
}

// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This pass performs some computation that is the same for all the depths
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
// all the depths at once or only the legal maximal depth for maximal fusion.
Expand Down
Loading