12
12
13
13
#include " mlir/Dialect/SCF/Transforms/Passes.h"
14
14
15
+ #include " mlir/Analysis/AliasAnalysis.h"
15
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
16
17
#include " mlir/Dialect/SCF/IR/SCF.h"
17
18
#include " mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
58
59
/// loop reads.
59
60
static bool haveNoReadsAfterWriteExceptSameIndex (
60
61
ParallelOp firstPloop, ParallelOp secondPloop,
61
- const IRMapping &firstToSecondPloopIndices) {
62
+ const IRMapping &firstToSecondPloopIndices,
63
+ llvm::function_ref<bool (Value, Value)> mayAlias) {
62
64
DenseMap<Value, SmallVector<ValueRange, 1 >> bufferStores;
65
+ SmallVector<Value> bufferStoresVec;
63
66
firstPloop.getBody ()->walk ([&](memref::StoreOp store) {
64
67
bufferStores[store.getMemRef ()].push_back (store.getIndices ());
68
+ bufferStoresVec.emplace_back (store.getMemRef ());
65
69
});
66
70
auto walkResult = secondPloop.getBody ()->walk ([&](memref::LoadOp load) {
71
+ Value loadMem = load.getMemRef ();
67
72
// Stop if the memref is defined in secondPloop body. Careful alias analysis
68
73
// is needed.
69
- auto *memrefDef = load. getMemRef () .getDefiningOp ();
74
+ auto *memrefDef = loadMem .getDefiningOp ();
70
75
if (memrefDef && memrefDef->getBlock () == load->getBlock ())
71
76
return WalkResult::interrupt ();
72
77
73
- auto write = bufferStores.find (load.getMemRef ());
78
+ for (Value store : bufferStoresVec)
79
+ if (store != loadMem && mayAlias (store, loadMem))
80
+ return WalkResult::interrupt ();
81
+
82
+ auto write = bufferStores.find (loadMem);
74
83
if (write == bufferStores.end ())
75
84
return WalkResult::advance ();
76
85
@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
98
107
/// write patterns.
99
108
/// Runs haveNoReadsAfterWriteExceptSameIndex in both directions and returns
/// success only when both checks pass. The first-to-second check uses
/// \p firstToSecondPloopIndices as supplied by the caller; the second-to-first
/// check uses a mapping built here from secondPloop's body arguments to
/// firstPloop's body arguments. \p mayAlias is forwarded unchanged to both
/// checks.
static LogicalResult
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
                   const IRMapping &firstToSecondPloopIndices,
                   llvm::function_ref<bool(Value, Value)> mayAlias) {
  // Forward direction: stores of firstPloop against loads of secondPloop.
  if (!haveNoReadsAfterWriteExceptSameIndex(
          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
    return failure();

  // Reverse direction: swap the loop roles and invert the index mapping.
  IRMapping secondToFirstPloopIndices;
  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
                                firstPloop.getBody()->getArguments());
  return success(haveNoReadsAfterWriteExceptSameIndex(
      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
}
112
122
113
123
static bool isFusionLegal (ParallelOp firstPloop, ParallelOp secondPloop,
114
- const IRMapping &firstToSecondPloopIndices) {
124
+ const IRMapping &firstToSecondPloopIndices,
125
+ llvm::function_ref<bool (Value, Value)> mayAlias) {
115
126
return !hasNestedParallelOp (firstPloop) &&
116
127
!hasNestedParallelOp (secondPloop) &&
117
128
equalIterationSpaces (firstPloop, secondPloop) &&
118
129
succeeded (verifyDependencies (firstPloop, secondPloop,
119
- firstToSecondPloopIndices));
130
+ firstToSecondPloopIndices, mayAlias ));
120
131
}
121
132
122
133
/// Prepends operations of firstPloop's body into secondPloop's body.
123
134
static void fuseIfLegal (ParallelOp firstPloop, ParallelOp secondPloop,
124
- OpBuilder b) {
135
+ OpBuilder b,
136
+ llvm::function_ref<bool (Value, Value)> mayAlias) {
125
137
IRMapping firstToSecondPloopIndices;
126
138
firstToSecondPloopIndices.map (firstPloop.getBody ()->getArguments (),
127
139
secondPloop.getBody ()->getArguments ());
128
140
129
- if (!isFusionLegal (firstPloop, secondPloop, firstToSecondPloopIndices))
141
+ if (!isFusionLegal (firstPloop, secondPloop, firstToSecondPloopIndices,
142
+ mayAlias))
130
143
return ;
131
144
132
145
b.setInsertionPointToStart (secondPloop.getBody ());
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
135
148
firstPloop.erase ();
136
149
}
137
150
138
- void mlir::scf::naivelyFuseParallelOps (Region &region) {
151
+ void mlir::scf::naivelyFuseParallelOps (
152
+ Region &region, llvm::function_ref<bool (Value, Value)> mayAlias) {
139
153
OpBuilder b (region);
140
154
// Consider every single block and attempt to fuse adjacent loops.
141
155
for (auto &block : region) {
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
159
173
}
160
174
for (ArrayRef<ParallelOp> ploops : ploopChains) {
161
175
for (int i = 0 , e = ploops.size (); i + 1 < e; ++i)
162
- fuseIfLegal (ploops[i], ploops[i + 1 ], b);
176
+ fuseIfLegal (ploops[i], ploops[i + 1 ], b, mayAlias );
163
177
}
164
178
}
165
179
}
@@ -168,9 +182,15 @@ namespace {
168
182
struct ParallelLoopFusion
169
183
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
170
184
void runOnOperation () override {
185
+ auto &AA = getAnalysis<AliasAnalysis>();
186
+
187
+ auto mayAlias = [&](Value val1, Value val2) -> bool {
188
+ return !AA.alias (val1, val2).isNo ();
189
+ };
190
+
171
191
getOperation ()->walk ([&](Operation *child) {
172
192
for (Region ®ion : child->getRegions ())
173
- naivelyFuseParallelOps (region);
193
+ naivelyFuseParallelOps (region, mayAlias );
174
194
});
175
195
}
176
196
};
0 commit comments