Skip to content

Commit c0d2ea9

Browse files
authored
[mlir][scf] Improve scf.parallel fusion pass (#75852)
Abort fusion if a memref load may alias a buffer that is written, unless the load and the store use exactly the same memref value. Add an alias-check hook to `naivelyFuseParallelOps` so that users can customize alias checking. Use the builtin alias analysis in the `ParallelLoopFusion` pass.
1 parent 9aeb333 commit c0d2ea9

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ class ParallelOp;
3434
/// Fuses all adjacent scf.parallel operations with identical bounds and step
3535
/// into one scf.parallel operation. Uses a naive aliasing and dependency
3636
/// analysis.
37-
void naivelyFuseParallelOps(Region &region);
37+
/// Users can additionally customize alias checking with the `mayAlias` hook.
38+
/// `mayAlias` must return false if the two values are guaranteed not to alias.
39+
void naivelyFuseParallelOps(Region &region,
40+
llvm::function_ref<bool(Value, Value)> mayAlias);
3841

3942
/// Rewrite a for loop with bounds/step that potentially do not divide evenly
4043
/// into a for loop where the step divides the iteration space evenly, followed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/SCF/Transforms/Passes.h"
1414

15+
#include "mlir/Analysis/AliasAnalysis.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Dialect/SCF/IR/SCF.h"
1718
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
5859
/// loop reads.
5960
static bool haveNoReadsAfterWriteExceptSameIndex(
6061
ParallelOp firstPloop, ParallelOp secondPloop,
61-
const IRMapping &firstToSecondPloopIndices) {
62+
const IRMapping &firstToSecondPloopIndices,
63+
llvm::function_ref<bool(Value, Value)> mayAlias) {
6264
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
65+
SmallVector<Value> bufferStoresVec;
6366
firstPloop.getBody()->walk([&](memref::StoreOp store) {
6467
bufferStores[store.getMemRef()].push_back(store.getIndices());
68+
bufferStoresVec.emplace_back(store.getMemRef());
6569
});
6670
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
71+
Value loadMem = load.getMemRef();
6772
// Stop if the memref is defined in secondPloop body. Careful alias analysis
6873
// is needed.
69-
auto *memrefDef = load.getMemRef().getDefiningOp();
74+
auto *memrefDef = loadMem.getDefiningOp();
7075
if (memrefDef && memrefDef->getBlock() == load->getBlock())
7176
return WalkResult::interrupt();
7277

73-
auto write = bufferStores.find(load.getMemRef());
78+
for (Value store : bufferStoresVec)
79+
if (store != loadMem && mayAlias(store, loadMem))
80+
return WalkResult::interrupt();
81+
82+
auto write = bufferStores.find(loadMem);
7483
if (write == bufferStores.end())
7584
return WalkResult::advance();
7685

@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
98107
/// write patterns.
99108
static LogicalResult
100109
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
101-
const IRMapping &firstToSecondPloopIndices) {
102-
if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
103-
firstToSecondPloopIndices))
110+
const IRMapping &firstToSecondPloopIndices,
111+
llvm::function_ref<bool(Value, Value)> mayAlias) {
112+
if (!haveNoReadsAfterWriteExceptSameIndex(
113+
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
104114
return failure();
105115

106116
IRMapping secondToFirstPloopIndices;
107117
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
108118
firstPloop.getBody()->getArguments());
109119
return success(haveNoReadsAfterWriteExceptSameIndex(
110-
secondPloop, firstPloop, secondToFirstPloopIndices));
120+
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
111121
}
112122

113123
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
114-
const IRMapping &firstToSecondPloopIndices) {
124+
const IRMapping &firstToSecondPloopIndices,
125+
llvm::function_ref<bool(Value, Value)> mayAlias) {
115126
return !hasNestedParallelOp(firstPloop) &&
116127
!hasNestedParallelOp(secondPloop) &&
117128
equalIterationSpaces(firstPloop, secondPloop) &&
118129
succeeded(verifyDependencies(firstPloop, secondPloop,
119-
firstToSecondPloopIndices));
130+
firstToSecondPloopIndices, mayAlias));
120131
}
121132

122133
/// Prepends operations of firstPloop's body into secondPloop's body.
123134
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
124-
OpBuilder b) {
135+
OpBuilder b,
136+
llvm::function_ref<bool(Value, Value)> mayAlias) {
125137
IRMapping firstToSecondPloopIndices;
126138
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
127139
secondPloop.getBody()->getArguments());
128140

129-
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
141+
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
142+
mayAlias))
130143
return;
131144

132145
b.setInsertionPointToStart(secondPloop.getBody());
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
135148
firstPloop.erase();
136149
}
137150

138-
void mlir::scf::naivelyFuseParallelOps(Region &region) {
151+
void mlir::scf::naivelyFuseParallelOps(
152+
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
139153
OpBuilder b(region);
140154
// Consider every single block and attempt to fuse adjacent loops.
141155
for (auto &block : region) {
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
159173
}
160174
for (ArrayRef<ParallelOp> ploops : ploopChains) {
161175
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
162-
fuseIfLegal(ploops[i], ploops[i + 1], b);
176+
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
163177
}
164178
}
165179
}
@@ -168,9 +182,15 @@ namespace {
168182
struct ParallelLoopFusion
169183
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
170184
void runOnOperation() override {
185+
auto &AA = getAnalysis<AliasAnalysis>();
186+
187+
auto mayAlias = [&](Value val1, Value val2) -> bool {
188+
return !AA.alias(val1, val2).isNo();
189+
};
190+
171191
getOperation()->walk([&](Operation *child) {
172192
for (Region &region : child->getRegions())
173-
naivelyFuseParallelOps(region);
193+
naivelyFuseParallelOps(region, mayAlias);
174194
});
175195
}
176196
};

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,33 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
357357
// CHECK: }
358358
// CHECK: }
359359
// CHECK: memref.dealloc [[SUM]]
360+
361+
// -----
362+
363+
func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
364+
%C: memref<2x2xf32>, %result: memref<2x2xf32>,
365+
%sum: memref<2x2xf32>) {
366+
%c2 = arith.constant 2 : index
367+
%c0 = arith.constant 0 : index
368+
%c1 = arith.constant 1 : index
369+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
370+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
371+
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
372+
%sum_elem = arith.addf %B_elem, %C_elem : f32
373+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
374+
scf.yield
375+
}
376+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
377+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
378+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
379+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
380+
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
381+
scf.yield
382+
}
383+
return
384+
}
385+
386+
// %sum and %result may alias the other arguments, so do not fuse the loops.
387+
// CHECK-LABEL: func @do_not_fuse_alias
388+
// CHECK: scf.parallel
389+
// CHECK: scf.parallel

0 commit comments

Comments
 (0)