Skip to content

[MLIR][Affine] Fix fusion in the presence of cyclic deps in source nests #128397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
/// any dependence component is negative along any of `loops`.
bool isTilingValid(ArrayRef<AffineForOp> loops);

/// Returns true if the affine nest rooted at `root` has a cyclic dependence
/// among its affine memory accesses. The dependences considered are those
/// carried by the loops contained in `root` (including `root` itself) and
/// those carried by the loop bodies (blocks) contained in it. Dependences
/// carried by loops outer to `root` aren't relevant. This method doesn't
/// consider/account for aliases.
bool hasCyclicDependence(AffineForOp root);

} // namespace affine
} // namespace mlir

Expand Down
162 changes: 160 additions & 2 deletions mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "llvm/Support/MathExtras.h"

Expand All @@ -28,10 +28,138 @@
#include <optional>
#include <type_traits>

#define DEBUG_TYPE "affine-loop-analysis"

using namespace mlir;
using namespace mlir::affine;

#define DEBUG_TYPE "affine-loop-analysis"
namespace {

/// A directed graph to model relationships between MLIR Operations.
class DirectedOpGraph {
public:
/// Add a node to the graph.
void addNode(Operation *op) {
assert(!hasNode(op) && "node already added");
nodes.emplace_back(op);
edges[op] = {};
}

/// Add an edge from `src` to `dest`.
void addEdge(Operation *src, Operation *dest) {
// This is a multi-graph.
assert(hasNode(src) && "src node does not exist in graph");
assert(hasNode(dest) && "dest node does not exist in graph");
edges[src].push_back(getNode(dest));
}

/// Returns true if there is a (directed) cycle in the graph.
bool hasCycle() { return dfs(/*cycleCheck=*/true); }

void printEdges() {
for (auto &en : edges) {
llvm::dbgs() << *en.first << " (" << en.first << ")"
<< " has " << en.second.size() << " edges:\n";
for (auto *node : en.second) {
llvm::dbgs() << '\t' << *node->op << '\n';
}
}
}

private:
/// A node of a directed graph between MLIR Operations to model various
/// relationships. This is meant to be used internally.
struct DGNode {
DGNode(Operation *op) : op(op) {};
Operation *op;

// Start and finish visit numbers are standard in DFS to implement things
// like finding strongly connected components. These numbers are modified
// during analyses on the graph and so seemingly const API methods will be
// non-const.

/// Start visit number.
int vn = -1;

/// Finish visit number.
int fn = -1;
};

/// Get internal node corresponding to `op`.
DGNode *getNode(Operation *op) {
auto *value =
llvm::find_if(nodes, [&](const DGNode &node) { return node.op == op; });
assert(value != nodes.end() && "node doesn't exist in graph");
return &*value;
}

/// Returns true if `key` is in the graph.
bool hasNode(Operation *key) const {
return llvm::find_if(nodes, [&](const DGNode &node) {
return node.op == key;
}) != nodes.end();
}

/// Perform a depth-first traversal of the graph setting visited and finished
/// numbers. If `cycleCheck` is set, detects cycles and returns true as soon
/// as the first cycle is detected, and false if there are no cycles. If
/// `cycleCheck` is not set, completes the DFS and the `return` value doesn't
/// have a meaning.
bool dfs(bool cycleCheck = false) {
for (DGNode &node : nodes) {
node.vn = 0;
node.fn = -1;
}

unsigned time = 0;
for (DGNode &node : nodes) {
if (node.vn == 0) {
bool ret = dfsNode(node, cycleCheck, time);
// Check if a cycle was already found.
if (cycleCheck && ret)
return true;
} else if (cycleCheck && node.fn == -1) {
// We have encountered a node whose visit has started but it's not
// finished. So we have a cycle.
return true;
}
}
return false;
}

/// Perform depth-first traversal starting at `node`. Return true
/// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
auto nodeEdges = edges.find(node.op);
assert(nodeEdges != edges.end() && "missing node in graph");
node.vn = ++time;

for (auto &neighbour : nodeEdges->second) {
if (neighbour->vn == 0) {
bool ret = dfsNode(*neighbour, cycleCheck, time);
if (cycleCheck && ret)
return true;
} else if (cycleCheck && neighbour->fn == -1) {
// We have encountered a node whose visit has started but it's not
// finished. So we have a cycle.
return true;
}
}

// Update finish time.
node.fn = ++time;

return false;
}

// The list of nodes. The storage is owned by this class.
SmallVector<DGNode> nodes;

// Edges as an adjacency list.
DenseMap<Operation *, SmallVector<DGNode *>> edges;
};

} // namespace

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
Expand Down Expand Up @@ -447,3 +575,33 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {

return true;
}

/// Returns true if the graph of dependences among the affine read/write ops
/// nested under `root` has a directed cycle. Only pairs of accesses on the
/// same memref value are tested against each other.
bool mlir::affine::hasCyclicDependence(AffineForOp root) {
  // Collect all the affine memory accesses nested under `root`; each one
  // becomes a node of the dependence graph.
  DirectedOpGraph graph;
  SmallVector<MemRefAccess> accesses;
  root->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
      accesses.emplace_back(op);
      graph.addNode(op);
    }
  });

  // Construct the dependence graph for all the collected accesses.
  unsigned rootDepth = getNestingDepth(root);
  for (const auto &accA : accesses) {
    for (const auto &accB : accesses) {
      if (accA.memref != accB.memref)
        continue;
      // Perform the dependence check on all surrounding loops + the body,
      // starting at the depth just past `rootDepth`.
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
        if (noDependence(checkMemrefAccessDependence(accA, accB, d)))
          continue;
        graph.addEdge(accA.opInst, accB.opInst);
        // One edge from A to B is all that cycle detection needs; checking
        // the remaining depths could only add parallel edges.
        break;
      }
    }
  }
  return graph.hasCycle();
}
Loading