Skip to content

Commit fb99c74

Browse files
committed
[MLIR][Affine] Fix fusion in the presence of cyclic deps in source nests
Fixes: #61820 Fix affine fusion in the presence of cyclic deps in the source nest. In such cases, the nest being fused can't be executed multiple times. Add a utility to check for dependence cycles and use it in fusion. This fixes both sibling and producer-consumer fusion, where nests with cyclic dependences (typically reductions) were, in some cases, being incorrectly fused in. The test case also exercised, and required, a fix to the check that the redundant computation is within the specified threshold.
1 parent 86ef031 commit fb99c74

File tree

5 files changed

+363
-47
lines changed

5 files changed

+363
-47
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
119119
/// any dependence component is negative along any of `loops`.
120120
bool isTilingValid(ArrayRef<AffineForOp> loops);
121121

122+
/// Returns true if the affine nest rooted at `root` has a cyclic dependence
123+
/// among its affine memory accesses. The dependence could be through any
124+
/// dependences carried by loops contained in `root` (inclusive of `root`) and
125+
/// those carried by loop bodies (blocks) contained. Dependences carried by
126+
/// loops outer to `root` aren't relevant.
127+
bool hasCyclicDependence(AffineForOp root);
128+
122129
} // namespace affine
123130
} // namespace mlir
124131

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1717
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1818
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
19-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
19+
#include "mlir/Dialect/Affine/Analysis/Utils.h"
2020
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
2121
#include "llvm/Support/MathExtras.h"
2222

@@ -28,10 +28,138 @@
2828
#include <optional>
2929
#include <type_traits>
3030

31+
#define DEBUG_TYPE "affine-loop-analysis"
32+
3133
using namespace mlir;
3234
using namespace mlir::affine;
3335

34-
#define DEBUG_TYPE "affine-loop-analysis"
36+
namespace {
37+
38+
/// A directed graph to model relationships between MLIR Operations.
39+
class DirectedOpGraph {
40+
public:
41+
/// Add a node to
42+
void addNode(Operation *op) {
43+
assert(!hasNode(op) && "node already added");
44+
nodes.emplace_back(op);
45+
edges[op] = {};
46+
}
47+
48+
/// Add an edge between `src` and `dest`.
49+
void addEdge(Operation *src, Operation *dest) {
50+
// This is a multi-graph.
51+
assert(hasNode(src) && "src node does not exist in graph");
52+
assert(hasNode(dest) && "dest node does not exist in graph");
53+
edges[src].push_back(getNode(dest));
54+
}
55+
56+
/// Returns true if there is a (directed) cycle in the graph.
57+
bool hasCycle() { return dfsImpl(/*cycleCheck=*/true); }
58+
59+
void printEdges() {
60+
for (auto &en : edges) {
61+
llvm::dbgs() << *en.first << " (" << en.first << ")"
62+
<< " has " << en.second.size() << " edges:\n";
63+
for (auto *node : en.second) {
64+
llvm::dbgs() << '\t' << *node->op << '\n';
65+
}
66+
}
67+
}
68+
69+
private:
70+
/// A node of a directed graph between MLIR Operations to model various
71+
/// relationships. This is meant to be used internally.
72+
struct DGNode {
73+
DGNode(Operation *op) : op(op) {};
74+
Operation *op;
75+
76+
// Start and finish visit numbers are standard in DFS to implement things
77+
// strongly connected components. These numbers are modified during analyses
78+
// on the graph and so seemingly const API methods will be non-const.
79+
80+
/// Start visit number.
81+
int vn = -1;
82+
83+
/// Finish visit number.
84+
int fn = -1;
85+
};
86+
87+
/// Get internal node corresponding to `op`.
88+
DGNode *getNode(Operation *op) {
89+
auto *value =
90+
llvm::find_if(nodes, [&](const DGNode &node) { return node.op == op; });
91+
assert(value != nodes.end() && "node doesn't exist in graph");
92+
return &*value;
93+
}
94+
95+
/// Returns true if `key` is in the graph.
96+
bool hasNode(Operation *key) const {
97+
return llvm::find_if(nodes, [&](const DGNode &node) {
98+
return node.op == key;
99+
}) != nodes.end();
100+
}
101+
102+
/// Perform a depth-first traversal of the graph setting visited and finished
103+
/// numbers. If `cycleCheck` is set, detects cycles and returns true as soon
104+
/// as the first cycle is detected, and false if there are no cycles. If
105+
/// `cycleCheck` is not set, completes the DFS and the `return` value doesn't
106+
/// have a meaning.
107+
bool dfsImpl(bool cycleCheck = false) {
108+
for (DGNode &node : nodes)
109+
node.vn = 0;
110+
111+
unsigned time = 0;
112+
for (DGNode &node : nodes) {
113+
if (node.vn == 0) {
114+
bool ret = dfsNode(node, cycleCheck, time);
115+
// Check if a cycle was already found.
116+
if (cycleCheck && ret)
117+
return true;
118+
} else if (cycleCheck && node.fn == -1) {
119+
// We have encountered a node whose visit has started but it's not
120+
// finished. So we have a cycle.
121+
return true;
122+
}
123+
}
124+
return false;
125+
}
126+
127+
/// Perform depth-first traversal starting at `node`. Return true
128+
/// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
129+
bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
130+
auto nodeEdges = edges.find(node.op);
131+
assert(nodeEdges != edges.end() && "missing node in graph");
132+
// Depth first search from a given vertex.
133+
++time;
134+
node.vn = time;
135+
136+
for (auto &neighbour : nodeEdges->second) {
137+
if (neighbour->vn == 0) {
138+
bool ret = dfsNode(*neighbour, cycleCheck, time);
139+
if (cycleCheck && ret)
140+
return true;
141+
} else if (cycleCheck && neighbour->fn == -1) {
142+
// We have encountered a node whose visit has started but it's not
143+
// finished. So we have a cycle.
144+
return true;
145+
}
146+
}
147+
148+
++time;
149+
// Update finish time.
150+
node.fn = time;
151+
152+
return false;
153+
}
154+
155+
// The list of nodes. The storage is owned by this class.
156+
SmallVector<DGNode> nodes;
157+
158+
// Edges as an adjacency list.
159+
DenseMap<Operation *, SmallVector<DGNode *>> edges;
160+
};
161+
162+
} // namespace
35163

36164
/// Returns the trip count of the loop as an affine expression if the latter is
37165
/// expressible as an affine expression, and nullptr otherwise. The trip count
@@ -447,3 +575,33 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
447575

448576
return true;
449577
}
578+
579+
/// Returns true if the dependence graph over the affine memory accesses
/// nested under `root` contains a cycle. Dependences are checked at every
/// depth from just inside `root`'s own nesting depth up to and including the
/// innermost body common to each pair of accesses.
bool mlir::affine::hasCyclicDependence(AffineForOp root) {
  // Collect all the affine memory accesses in the nest and make each of them
  // a node of the dependence graph.
  DirectedOpGraph graph;
  SmallVector<MemRefAccess> accesses;
  root->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
      accesses.emplace_back(op);
      graph.addNode(op);
    }
  });

  // Construct the dependence graph for all the collected accesses.
  unsigned rootDepth = getNestingDepth(root);
  for (const auto &accA : accesses) {
    for (const auto &accB : accesses) {
      // Accesses to different memrefs are skipped.
      if (accA.memref != accB.memref)
        continue;
      // Perform the dependence on all surrounding loops + the body.
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
        if (noDependence(checkMemrefAccessDependence(accA, accB, d)))
          continue;
        // A single edge per ordered pair is enough for cycle detection, so
        // don't run further dependence checks for this pair.
        graph.addEdge(accA.opInst, accB.opInst);
        break;
      }
    }
  }
  return graph.hasCycle();
}

0 commit comments

Comments
 (0)