Skip to content

Commit eab6f2d

Browse files
authored
[MLIR][Affine] Fix fusion in the presence of cyclic deps in source nests (#128397)
Fixes: #61820. Fix affine fusion in the presence of cyclic dependences in the source nest. In such cases, the nest being fused can't be executed multiple times. Add a utility to check for dependence cycles and use it in fusion. This fixes both sibling and producer-consumer fusion, where nests with cyclic dependences (typically reductions) were in some cases being fused incorrectly. The test case also exercised, and required a fix to, the check that the redundant computation is within the specified threshold.
1 parent b3c51db commit eab6f2d

File tree

5 files changed

+386
-47
lines changed

5 files changed

+386
-47
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
119119
/// any dependence component is negative along any of `loops`.
120120
bool isTilingValid(ArrayRef<AffineForOp> loops);
121121

122+
/// Returns true if the affine nest rooted at `root` has a cyclic dependence
123+
/// among its affine memory accesses. A cycle may be formed by dependences
124+
/// carried by loops contained in `root` (inclusive of `root`) as well as by
125+
/// those carried by the contained loop bodies (blocks). Dependences carried by
126+
/// loops outer to `root` aren't relevant. This method doesn't account for
127+
/// memref aliasing.
128+
bool hasCyclicDependence(AffineForOp root);
129+
122130
} // namespace affine
123131
} // namespace mlir
124132

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1717
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1818
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
19-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
19+
#include "mlir/Dialect/Affine/Analysis/Utils.h"
2020
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
2121
#include "llvm/Support/MathExtras.h"
2222

@@ -28,10 +28,138 @@
2828
#include <optional>
2929
#include <type_traits>
3030

31+
#define DEBUG_TYPE "affine-loop-analysis"
32+
3133
using namespace mlir;
3234
using namespace mlir::affine;
3335

34-
#define DEBUG_TYPE "affine-loop-analysis"
36+
namespace {
37+
38+
/// A directed graph to model relationships between MLIR Operations.
39+
class DirectedOpGraph {
40+
public:
41+
/// Add a node to the graph.
42+
void addNode(Operation *op) {
43+
assert(!hasNode(op) && "node already added");
44+
nodes.emplace_back(op);
45+
edges[op] = {};
46+
}
47+
48+
/// Add an edge from `src` to `dest`.
49+
void addEdge(Operation *src, Operation *dest) {
50+
// This is a multi-graph.
51+
assert(hasNode(src) && "src node does not exist in graph");
52+
assert(hasNode(dest) && "dest node does not exist in graph");
53+
edges[src].push_back(getNode(dest));
54+
}
55+
56+
/// Returns true if there is a (directed) cycle in the graph.
57+
bool hasCycle() { return dfs(/*cycleCheck=*/true); }
58+
59+
void printEdges() {
60+
for (auto &en : edges) {
61+
llvm::dbgs() << *en.first << " (" << en.first << ")"
62+
<< " has " << en.second.size() << " edges:\n";
63+
for (auto *node : en.second) {
64+
llvm::dbgs() << '\t' << *node->op << '\n';
65+
}
66+
}
67+
}
68+
69+
private:
70+
/// A node of a directed graph between MLIR Operations to model various
71+
/// relationships. This is meant to be used internally.
72+
struct DGNode {
73+
DGNode(Operation *op) : op(op) {};
74+
Operation *op;
75+
76+
// Start and finish visit numbers are standard in DFS to implement things
77+
// like finding strongly connected components. These numbers are modified
78+
// during analyses on the graph and so seemingly const API methods will be
79+
// non-const.
80+
81+
/// Start visit number.
82+
int vn = -1;
83+
84+
/// Finish visit number.
85+
int fn = -1;
86+
};
87+
88+
/// Get internal node corresponding to `op`.
89+
DGNode *getNode(Operation *op) {
90+
auto *value =
91+
llvm::find_if(nodes, [&](const DGNode &node) { return node.op == op; });
92+
assert(value != nodes.end() && "node doesn't exist in graph");
93+
return &*value;
94+
}
95+
96+
/// Returns true if `key` is in the graph.
97+
bool hasNode(Operation *key) const {
98+
return llvm::find_if(nodes, [&](const DGNode &node) {
99+
return node.op == key;
100+
}) != nodes.end();
101+
}
102+
103+
/// Perform a depth-first traversal of the graph setting visited and finished
104+
/// numbers. If `cycleCheck` is set, detects cycles and returns true as soon
105+
/// as the first cycle is detected, and false if there are no cycles. If
106+
/// `cycleCheck` is not set, completes the DFS and the `return` value doesn't
107+
/// have a meaning.
108+
bool dfs(bool cycleCheck = false) {
109+
for (DGNode &node : nodes) {
110+
node.vn = 0;
111+
node.fn = -1;
112+
}
113+
114+
unsigned time = 0;
115+
for (DGNode &node : nodes) {
116+
if (node.vn == 0) {
117+
bool ret = dfsNode(node, cycleCheck, time);
118+
// Check if a cycle was already found.
119+
if (cycleCheck && ret)
120+
return true;
121+
} else if (cycleCheck && node.fn == -1) {
122+
// We have encountered a node whose visit has started but it's not
123+
// finished. So we have a cycle.
124+
return true;
125+
}
126+
}
127+
return false;
128+
}
129+
130+
/// Perform depth-first traversal starting at `node`. Return true
131+
/// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
132+
bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
133+
auto nodeEdges = edges.find(node.op);
134+
assert(nodeEdges != edges.end() && "missing node in graph");
135+
node.vn = ++time;
136+
137+
for (auto &neighbour : nodeEdges->second) {
138+
if (neighbour->vn == 0) {
139+
bool ret = dfsNode(*neighbour, cycleCheck, time);
140+
if (cycleCheck && ret)
141+
return true;
142+
} else if (cycleCheck && neighbour->fn == -1) {
143+
// We have encountered a node whose visit has started but it's not
144+
// finished. So we have a cycle.
145+
return true;
146+
}
147+
}
148+
149+
// Update finish time.
150+
node.fn = ++time;
151+
152+
return false;
153+
}
154+
155+
// The list of nodes. The storage is owned by this class.
156+
SmallVector<DGNode> nodes;
157+
158+
// Edges as an adjacency list.
159+
DenseMap<Operation *, SmallVector<DGNode *>> edges;
160+
};
161+
162+
} // namespace
35163

36164
/// Returns the trip count of the loop as an affine expression if the latter is
37165
/// expressible as an affine expression, and nullptr otherwise. The trip count
@@ -447,3 +575,33 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
447575

448576
return true;
449577
}
578+
579+
/// Returns true if the affine memory accesses nested under `root` form a
/// dependence cycle. Two accesses are only related if they use the same memref
/// value; aliasing between distinct memref values is not considered.
bool mlir::affine::hasCyclicDependence(AffineForOp root) {
  // Collect all the affine memory accesses in the nest; each one becomes a
  // node of the dependence graph.
  DirectedOpGraph graph;
  SmallVector<MemRefAccess> accesses;
  root->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
      accesses.emplace_back(op);
      graph.addNode(op);
    }
  });

  // Construct the dependence graph for all the collected accesses. Every
  // ordered pair is considered, including an access paired with itself.
  unsigned rootDepth = getNestingDepth(root);
  for (const auto &accA : accesses) {
    for (const auto &accB : accesses) {
      if (accA.memref != accB.memref)
        continue;
      // Perform the dependence check on all surrounding loops + the body,
      // i.e. at every depth from `rootDepth + 1` through `numCommonLoops + 1`.
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
        if (noDependence(checkMemrefAccessDependence(accA, accB, d)))
          continue;
        // A single edge per ordered pair is all cycle detection needs; stop
        // here to avoid further dependence checks and duplicate edges.
        graph.addEdge(accA.opInst, accB.opInst);
        break;
      }
    }
  }
  return graph.hasCycle();
}

0 commit comments

Comments
 (0)