Skip to content

Commit 469f9d5

Browse files
authored
[MLIR][Affine] Rewrite fusion helper hasNonAffineUsersOnPath for efficiency (#115588)
The hasNonAffineUsersOnPath utility used during fusion was terribly inefficient in its approach. Rewrite it efficiently to simply work based on use lists (sparse) instead of having to traverse all nodes of an MDG repeatedly and all operands of all ops of each node in the relevant range. On large models (with 10s of thousands of loop nests), this reduces fusion pass time by nearly 2x (cutting down several tens of seconds).
1 parent 8d43c88 commit 469f9d5

File tree

1 file changed

+35
-48
lines changed

1 file changed

+35
-48
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -343,61 +343,48 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
343343
return newMemRef;
344344
}
345345

346-
/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
347-
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
348-
/// true. Otherwise, return false.
349-
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
350-
Value memref,
351-
MemRefDependenceGraph *mdg) {
352-
auto *srcNode = mdg->getNode(srcId);
353-
auto *dstNode = mdg->getNode(dstId);
354-
Value::user_range users = memref.getUsers();
355-
// For each MemRefDependenceGraph's node that is between 'srcNode' and
356-
// 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
357-
// non-affine operation in the node accesses the 'memref'.
358-
for (auto &idAndNode : mdg->nodes) {
359-
Operation *op = idAndNode.second.op;
360-
// Take care of operations between 'srcNode' and 'dstNode'.
361-
if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
362-
// Walk inside the operation to find any use of the memref.
363-
// Interrupt the walk if found.
364-
auto walkResult = op->walk([&](Operation *user) {
365-
// Skip affine ops.
366-
if (isa<AffineMapAccessInterface>(*user))
367-
return WalkResult::advance();
368-
// Find a non-affine op that uses the memref.
369-
if (llvm::is_contained(users, user))
370-
return WalkResult::interrupt();
371-
return WalkResult::advance();
372-
});
373-
if (walkResult.wasInterrupted())
374-
return true;
375-
}
376-
}
377-
return false;
346+
/// Returns true if there are any non-affine uses of `memref` in any of
347+
/// the operations between `start` and `end` (both exclusive). Any other
348+
/// than affine read/write are treated as non-affine uses of `memref`.
349+
static bool hasNonAffineUsersOnPath(Operation *start, Operation *end,
350+
Value memref) {
351+
assert(start->getBlock() == end->getBlock());
352+
assert(start->isBeforeInBlock(end) && "start expected to be before end");
353+
Block *block = start->getBlock();
354+
// Check if there is a non-affine memref user in any op between `start` and
355+
// `end`.
356+
return llvm::any_of(memref.getUsers(), [&](Operation *user) {
357+
if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user))
358+
return false;
359+
Operation *ancestor = block->findAncestorOpInBlock(*user);
360+
return ancestor && start->isBeforeInBlock(ancestor) &&
361+
ancestor->isBeforeInBlock(end);
362+
});
378363
}
379364

380-
/// Check whether a memref value in node 'srcId' has a non-affine that
381-
/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
382-
/// 'dstNode').
383-
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
384-
MemRefDependenceGraph *mdg) {
385-
// Collect memref values in node 'srcId'.
386-
auto *srcNode = mdg->getNode(srcId);
365+
/// Check whether a memref value used in any operation of 'src' has a
366+
/// non-affine operation that is between `src` and `end` (exclusive of `src`
367+
/// and `end`) where `src` and `end` are expected to be in the same Block.
368+
/// Any other than affine read/write are treated as non-affine uses of memref.
369+
static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) {
370+
assert(src->getBlock() == end->getBlock() && "same block expected");
371+
372+
// Trivial case. `src` and `end` are exclusive.
373+
if (src == end || end->isBeforeInBlock(src))
374+
return false;
375+
376+
// Collect relevant memref values.
387377
llvm::SmallDenseSet<Value, 2> memRefValues;
388-
srcNode->op->walk([&](Operation *op) {
389-
// Skip affine ops.
390-
if (isa<AffineForOp>(op))
391-
return WalkResult::advance();
378+
src->walk([&](Operation *op) {
392379
for (Value v : op->getOperands())
393380
// Collect memref values only.
394381
if (isa<MemRefType>(v.getType()))
395382
memRefValues.insert(v);
396383
return WalkResult::advance();
397384
});
398-
// Looking for users between node 'srcId' and node 'dstId'.
385+
// Look for non-affine users between `src` and `end`.
399386
return llvm::any_of(memRefValues, [&](Value memref) {
400-
return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
387+
return hasNonAffineUsersOnPath(src, end, memref);
401388
});
402389
}
403390

@@ -884,7 +871,7 @@ struct GreedyFusion {
884871
// escaping memrefs so we can limit this check to only scenarios with
885872
// escaping memrefs.
886873
if (!srcEscapingMemRefs.empty() &&
887-
hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
874+
hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) {
888875
LLVM_DEBUG(llvm::dbgs()
889876
<< "Can't fuse: non-affine users in between the loops\n");
890877
continue;
@@ -1247,8 +1234,8 @@ struct GreedyFusion {
12471234

12481235
// Skip if a memref value in one node is used by a non-affine memref
12491236
// access that lies between 'dstNode' and 'sibNode'.
1250-
if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
1251-
hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
1237+
if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) ||
1238+
hasNonAffineUsersOnPath(sibNode->op, dstNode->op))
12521239
return false;
12531240
return true;
12541241
};

0 commit comments

Comments
 (0)