Skip to content

Commit c79ffb0

Browse files
authored
Generalize affine fusion to work at all depths and inside other region-holding ops (#72288)
Generalize affine fusion to work at any inner depth, i.e., to fuse loops nested inside other affine.for ops or even inside scf.for or scf.while nests. Apply fusion in post order on all affine nests under the pass's top-level operation. Fix MDG init for blocks inside other affine nests. Relax the unnecessary requirement for unique vars during merge and align of FlatLinearValueConstraints. There are several cases where FlatLinearValueConstraints needs to have duplicate Values for the dimensions: e.g., in dependence relation systems, the source and destination accesses could have common loop IVs. `mergeAndAlign` can be done even in the presence of reappearing Values by simply aligning from left to right in that order. While at it, drop outdated comments and improve some debug messages.
1 parent f5bfc83 commit c79ffb0

File tree

8 files changed

+320
-52
lines changed

8 files changed

+320
-52
lines changed

mlir/include/mlir/Analysis/FlatLinearValueConstraints.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
374374
setValue(i, values[i - start]);
375375
}
376376

377-
/// Looks up the position of the variable with the specified Value. Returns
378-
/// true if found (false otherwise). `pos` is set to the (column) position of
379-
/// the variable.
380-
bool findVar(Value val, unsigned *pos) const;
377+
/// Looks up the position of the variable with the specified Value starting
378+
/// with variables at offset `offset`. Returns true if found (false
379+
/// otherwise). `pos` is set to the (column) position of the variable.
380+
bool findVar(Value val, unsigned *pos, unsigned offset = 0) const;
381381

382382
/// Returns true if a variable with the specified Value exists, false
383383
/// otherwise.

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,20 @@ struct MemRefDependenceGraph {
213213
};
214214

215215
/// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered
216-
/// from the outermost 'affine.for' operation to the innermost one.
216+
/// from the outermost 'affine.for' operation to the innermost one while not
217+
/// traversing outside of the surrounding affine scope.
217218
void getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
218219

219220
/// Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel
220-
/// ops ordered from the outermost one to the innermost.
221+
/// ops ordered from the outermost one to the innermost while not traversing
222+
/// outside of the surrounding affine scope.
221223
void getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs);
222224

223225
/// Populates 'ops' with affine operations enclosing `op` ordered from outermost
224-
/// to innermost. affine.for, affine.if, or affine.parallel ops comprise such
225-
/// surrounding affine ops.
226-
/// TODO: Change this to return a list of enclosing ops up until the op that
227-
/// starts an `AffineScope`. In such a case, `ops` is guaranteed by design to
228-
/// have a successive chain of affine parent ops, and this is primarily what is
229-
/// needed for most analyses.
226+
/// to innermost while stopping at the boundary of the affine scope. affine.for,
227+
/// affine.if, or affine.parallel ops comprise such surrounding affine ops.
228+
/// `ops` is guaranteed by design to have a successive chain of affine parent
229+
/// ops.
230230
void getEnclosingAffineOps(Operation &op, SmallVectorImpl<Operation *> *ops);
231231

232232
/// Returns the nesting depth of this operation, i.e., the number of loops

mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ class FusionStrategy {
110110
/// returns a FusionResult explaining why fusion is not feasible.
111111
/// NOTE: This function is not feature complete and should only be used in
112112
/// testing.
113-
/// TODO: Update comments when this function is fully implemented.
114113
FusionResult
115114
canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
116115
ComputationSliceState *srcSlice,

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -958,15 +958,15 @@ areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
958958
/// so that they have the union of all variables, with A's original
959959
/// variables appearing first followed by any of B's variables that didn't
960960
/// appear in A. Local variables in B that have the same division
961-
/// representation as local variables in A are merged into one.
961+
/// representation as local variables in A are merged into one. We allow A
962+
/// and B to have non-unique values for their variables; in such cases, they are
963+
/// still aligned, with the variables appearing first in one system aligned
964+
/// with those appearing first in the other system, from left to right.
962965
// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
963966
// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
964967
static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
965968
FlatLinearValueConstraints *b) {
966969
assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
967-
// A merge/align isn't meaningful if a cst's vars aren't distinct.
968-
assert(areVarsUnique(*a) && "A's values aren't unique");
969-
assert(areVarsUnique(*b) && "B's values aren't unique");
970970

971971
assert(llvm::all_of(
972972
llvm::drop_begin(a->getMaybeValues(), offset),
@@ -982,9 +982,12 @@ static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
982982
{
983983
// Merge dims from A into B.
984984
unsigned d = offset;
985-
for (auto aDimValue : aDimValues) {
985+
for (Value aDimValue : aDimValues) {
986986
unsigned loc;
987-
if (b->findVar(aDimValue, &loc)) {
987+
// Find from the position `d` since we'd like to also consider the
988+
// possibility of multiple variables with the same `Value`. We align with
989+
// the next appearing one.
990+
if (b->findVar(aDimValue, &loc, d)) {
988991
assert(loc >= offset && "A's dim appears in B's aligned range");
989992
assert(loc < b->getNumDimVars() &&
990993
"A's dim appears in B's non-dim position");
@@ -1017,15 +1020,12 @@ void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
10171020
}
10181021

10191022
/// Merge and align symbols of `this` and `other` such that both get union of
1020-
/// of symbols that are unique. Symbols in `this` and `other` should be
1021-
/// unique. Symbols with Value as `None` are considered to be inequal to all
1022-
/// other symbols.
1023+
/// symbols. Existing symbols need not be unique; they will be aligned from
1024+
/// left to right with duplicates aligned in the same order. Symbols with Value
1025+
/// as `None` are considered to be unequal to all other symbols.
10231026
void FlatLinearValueConstraints::mergeSymbolVars(
10241027
FlatLinearValueConstraints &other) {
10251028

1026-
assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
1027-
assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
1028-
10291029
SmallVector<Value, 4> aSymValues;
10301030
getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);
10311031

@@ -1034,8 +1034,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
10341034
for (Value aSymValue : aSymValues) {
10351035
unsigned loc;
10361036
// If the var is a symbol in `other`, then align it, otherwise assume that
1037-
// it is a new symbol
1038-
if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
1037+
// it is a new symbol. Search in `other` starting at position `s` since the
1038+
// left of it is aligned.
1039+
if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
10391040
loc < other.getNumDimAndSymbolVars())
10401041
other.swapVar(s, loc);
10411042
else
@@ -1051,8 +1052,6 @@ void FlatLinearValueConstraints::mergeSymbolVars(
10511052

10521053
assert(getNumSymbolVars() == other.getNumSymbolVars() &&
10531054
"expected same number of symbols");
1054-
assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
1055-
assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
10561055
}
10571056

10581057
bool FlatLinearValueConstraints::hasConsistentState() const {
@@ -1104,9 +1103,11 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
11041103
return alignedMap;
11051104
}
11061105

1107-
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos) const {
1108-
unsigned i = 0;
1109-
for (const auto &mayBeVar : values) {
1106+
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1107+
unsigned offset) const {
1108+
unsigned i = offset;
1109+
for (const auto &mayBeVar :
1110+
ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
11101111
if (mayBeVar && *mayBeVar == val) {
11111112
*pos = i;
11121113
return true;

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,15 @@ bool MemRefDependenceGraph::init() {
199199
continue;
200200
SmallVector<AffineForOp, 4> loops;
201201
getAffineForIVs(*user, &loops);
202-
if (loops.empty())
202+
// Find the surrounding affine.for nested immediately within the
203+
// block.
204+
auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
205+
return loop->getBlock() == &block;
206+
});
207+
if (it == loops.end())
203208
continue;
204-
assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
205-
unsigned userLoopNestId = forToNodeMap[loops[0]];
209+
assert(forToNodeMap.count(*it) > 0 && "missing mapping");
210+
unsigned userLoopNestId = forToNodeMap[*it];
206211
addEdge(node.id, userLoopNestId, value);
207212
}
208213
}
@@ -631,8 +636,8 @@ void mlir::affine::getAffineForIVs(Operation &op,
631636
AffineForOp currAffineForOp;
632637
// Traverse up the hierarchy collecting all 'affine.for' operation while
633638
// skipping over 'affine.if' operations.
634-
while (currOp) {
635-
if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
639+
while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
640+
if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
636641
loops->push_back(currAffineForOp);
637642
currOp = currOp->getParentOp();
638643
}
@@ -646,7 +651,7 @@ void mlir::affine::getEnclosingAffineOps(Operation &op,
646651

647652
// Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
648653
// affine.parallel operations.
649-
while (currOp) {
654+
while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
650655
if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
651656
ops->push_back(currOp);
652657
currOp = currOp->getParentOp();

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,16 @@ struct GreedyFusion {
896896
if (fusedLoopInsPoint == nullptr)
897897
continue;
898898

899+
// It's possible this fusion is at an inner depth (i.e., there are
900+
// common surrounding affine loops for the source and destination for
901+
// ops). We need to get this number because the call to canFuseLoops
902+
// needs to be passed the absolute depth. The max legal depth and the
903+
// depths we try below are however *relative* and as such don't include
904+
// the common depth.
905+
SmallVector<AffineForOp, 4> surroundingLoops;
906+
getAffineForIVs(*dstAffineForOp, &surroundingLoops);
907+
unsigned numSurroundingLoops = surroundingLoops.size();
908+
899909
// Compute the innermost common loop depth for dstNode
900910
// producer-consumer loads/stores.
901911
SmallVector<Operation *, 2> dstMemrefOps;
@@ -907,7 +917,8 @@ struct GreedyFusion {
907917
if (producerConsumerMemrefs.count(
908918
cast<AffineWriteOpInterface>(op).getMemRef()))
909919
dstMemrefOps.push_back(op);
910-
unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
920+
unsigned dstLoopDepthTest =
921+
getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
911922

912923
// Check the feasibility of fusing src loop nest into dst loop nest
913924
// at loop depths in range [1, dstLoopDepthTest].
@@ -916,9 +927,10 @@ struct GreedyFusion {
916927
depthSliceUnions.resize(dstLoopDepthTest);
917928
FusionStrategy strategy(FusionStrategy::ProducerConsumer);
918929
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
919-
FusionResult result = affine::canFuseLoops(
920-
srcAffineForOp, dstAffineForOp,
921-
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
930+
FusionResult result =
931+
affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
932+
/*dstLoopDepth=*/i + numSurroundingLoops,
933+
&depthSliceUnions[i - 1], strategy);
922934

923935
if (result.value == FusionResult::Success)
924936
maxLegalFusionDepth = i;
@@ -1125,9 +1137,18 @@ struct GreedyFusion {
11251137
SmallVector<Operation *, 2> dstLoadOpInsts;
11261138
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
11271139

1140+
// It's possible this fusion is at an inner depth (i.e., there are common
1141+
// surrounding affine loops for the source and destination for ops). We
1142+
// need to get this number because the call to canFuseLoops needs to be
1143+
// passed the absolute depth. The max legal depth and the depths we try
1144+
// below are however *relative* and as such don't include the common
1145+
// depth.
1146+
SmallVector<AffineForOp, 4> surroundingLoops;
1147+
getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1148+
unsigned numSurroundingLoops = surroundingLoops.size();
11281149
SmallVector<AffineForOp, 4> dstLoopIVs;
11291150
getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1130-
unsigned dstLoopDepthTest = dstLoopIVs.size();
1151+
unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
11311152
auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
11321153

11331154
// Compute loop depth and slice union for fusion.
@@ -1136,14 +1157,18 @@ struct GreedyFusion {
11361157
unsigned maxLegalFusionDepth = 0;
11371158
FusionStrategy strategy(memref);
11381159
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1139-
FusionResult result = affine::canFuseLoops(
1140-
sibAffineForOp, dstAffineForOp,
1141-
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
1160+
FusionResult result =
1161+
affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1162+
/*dstLoopDepth=*/i + numSurroundingLoops,
1163+
&depthSliceUnions[i - 1], strategy);
11421164

11431165
if (result.value == FusionResult::Success)
11441166
maxLegalFusionDepth = i;
11451167
}
11461168

1169+
LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1170+
<< maxLegalFusionDepth << '\n');
1171+
11471172
// Skip if fusion is not feasible at any loop depths.
11481173
if (maxLegalFusionDepth == 0)
11491174
continue;
@@ -1238,9 +1263,15 @@ struct GreedyFusion {
12381263
SmallVector<AffineForOp, 4> loops;
12391264
getAffineForIVs(*user, &loops);
12401265
// Skip 'use' if it is not within a loop nest.
1241-
if (loops.empty())
1266+
// Find the surrounding affine.for nested immediately within the
1267+
// block.
1268+
auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1269+
return loop->getBlock() == &mdg->block;
1270+
});
1271+
// Skip 'use' if it is not within a loop nest in `block`.
1272+
if (it == loops.end())
12421273
continue;
1243-
Node *sibNode = mdg->getForOpNode(loops[0]);
1274+
Node *sibNode = mdg->getForOpNode(*it);
12441275
assert(sibNode != nullptr);
12451276
// Skip 'use' if it is not a sibling to 'dstNode'.
12461277
if (sibNode->id == dstNode->id)
@@ -1373,9 +1404,17 @@ void LoopFusion::runOnBlock(Block *block) {
13731404
}
13741405

13751406
void LoopFusion::runOnOperation() {
1376-
for (Region &region : getOperation()->getRegions())
1377-
for (Block &block : region.getBlocks())
1378-
runOnBlock(&block);
1407+
// Call fusion on every op that has at least two affine.for nests (in post
1408+
// order).
1409+
getOperation()->walk([&](Operation *op) {
1410+
for (Region &region : op->getRegions()) {
1411+
for (Block &block : region.getBlocks()) {
1412+
auto affineFors = block.getOps<AffineForOp>();
1413+
if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1414+
runOnBlock(&block);
1415+
}
1416+
}
1417+
});
13791418
}
13801419

13811420
std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
243243
return loopDepth;
244244
}
245245

246-
// TODO: Prevent fusion of loop nests with side-effecting operations.
247246
// TODO: This pass performs some computation that is the same for all the depths
248247
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
249248
// all the depths at once or only the legal maximal depth for maximal fusion.

0 commit comments

Comments
 (0)