Skip to content

Commit 71513a7

Browse files
committed
[MLIR][Affine] Improve load elimination
Fixes llvm#62639. Differential Revision: https://reviews.llvm.org/D154769
1 parent 758c464 commit 71513a7

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -862,9 +862,10 @@ bool mlir::affine::hasNoInterveningEffect(Operation *start, T memOp) {
862862
/// other operations will overwrite the memory loaded between the given load
863863
/// and store. If such a value exists, the replaced `loadOp` will be added to
864864
/// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
865-
static LogicalResult forwardStoreToLoad(
866-
AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
867-
SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
865+
static void forwardStoreToLoad(AffineReadOpInterface loadOp,
866+
SmallVectorImpl<Operation *> &loadOpsToErase,
867+
SmallPtrSetImpl<Value> &memrefsToErase,
868+
DominanceInfo &domInfo) {
868869

869870
// The store op candidate for forwarding that satisfies all conditions
870871
// to replace the load, if any.
@@ -911,21 +912,20 @@ static LogicalResult forwardStoreToLoad(
911912
}
912913

913914
if (!lastWriteStoreOp)
914-
return failure();
915+
return;
915916

916917
// Perform the actual store to load forwarding.
917918
Value storeVal =
918919
cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
919920
// Check if 2 values have the same shape. This is needed for affine vector
920921
// loads and stores.
921922
if (storeVal.getType() != loadOp.getValue().getType())
922-
return failure();
923+
return;
923924
loadOp.getValue().replaceAllUsesWith(storeVal);
924925
// Record the memref for a later sweep to optimize away.
925926
memrefsToErase.insert(loadOp.getMemRef());
926927
// Record this to erase later.
927928
loadOpsToErase.push_back(loadOp);
928-
return success();
929929
}
930930

931931
template bool
@@ -995,16 +995,16 @@ static void loadCSE(AffineReadOpInterface loadA,
995995
MemRefAccess srcAccess(loadB);
996996
MemRefAccess destAccess(loadA);
997997

998-
// 1. The accesses have to be to the same location.
998+
// 1. The accesses should be to be to the same location.
999999
if (srcAccess != destAccess) {
10001000
continue;
10011001
}
10021002

1003-
// 2. The store has to dominate the load op to be candidate.
1003+
// 2. loadB should dominate loadA.
10041004
if (!domInfo.dominates(loadB, loadA))
10051005
continue;
10061006

1007-
// 3. There is no write between loadA and loadB.
1007+
// 3. There should not be a write between loadA and loadB.
10081008
if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
10091009
loadB.getOperation(), loadA))
10101010
continue;
@@ -1073,13 +1073,8 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10731073

10741074
// Walk all load's and perform store to load forwarding.
10751075
f.walk([&](AffineReadOpInterface loadOp) {
1076-
if (failed(
1077-
forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
1078-
loadCSE(loadOp, opsToErase, domInfo);
1079-
}
1076+
forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo);
10801077
});
1081-
1082-
// Erase all load op's whose results were replaced with store fwd'ed ones.
10831078
for (auto *op : opsToErase)
10841079
op->erase();
10851080
opsToErase.clear();
@@ -1088,9 +1083,9 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
10881083
f.walk([&](AffineWriteOpInterface storeOp) {
10891084
findUnusedStore(storeOp, opsToErase, postDomInfo);
10901085
});
1091-
// Erase all store op's which don't impact the program
10921086
for (auto *op : opsToErase)
10931087
op->erase();
1088+
opsToErase.clear();
10941089

10951090
// Check if the store fwd'ed memrefs are now left with only stores and
10961091
// deallocs and can thus be completely deleted. Note: the canonicalize pass
@@ -1114,6 +1109,15 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
11141109
user->erase();
11151110
defOp->erase();
11161111
}
1112+
1113+
// To eliminate as many loads as possible, run load CSE after eliminating
1114+
// stores. Otherwise, some stores are wrongly seen as having an intervening
1115+
// effect.
1116+
f.walk([&](AffineReadOpInterface loadOp) {
1117+
loadCSE(loadOp, opsToErase, domInfo);
1118+
});
1119+
for (auto *op : opsToErase)
1120+
op->erase();
11171121
}
11181122

11191123
// Perform the replacement in `op`.

mlir/test/Dialect/Affine/scalrep.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,31 @@ func.func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index)
280280
return
281281
}
282282

283+
// CHECK-LABEL: func @elim_load_after_store
284+
func.func @elim_load_after_store(%arg0: memref<100xf32>, %arg1: memref<100xf32>) {
285+
%alloc = memref.alloc() : memref<1xf32>
286+
%alloc_0 = memref.alloc() : memref<1xf32>
287+
// CHECK: affine.for
288+
affine.for %arg2 = 0 to 100 {
289+
// CHECK: affine.load
290+
%0 = affine.load %arg0[%arg2] : memref<100xf32>
291+
%1 = affine.load %arg0[%arg2] : memref<100xf32>
292+
// CHECK: arith.addf
293+
%2 = arith.addf %0, %1 : f32
294+
affine.store %2, %alloc_0[0] : memref<1xf32>
295+
%3 = affine.load %arg0[%arg2] : memref<100xf32>
296+
%4 = affine.load %alloc_0[0] : memref<1xf32>
297+
// CHECK-NEXT: arith.addf
298+
%5 = arith.addf %3, %4 : f32
299+
affine.store %5, %alloc[0] : memref<1xf32>
300+
%6 = affine.load %arg0[%arg2] : memref<100xf32>
301+
%7 = affine.load %alloc[0] : memref<1xf32>
302+
%8 = arith.addf %6, %7 : f32
303+
affine.store %8, %arg1[%arg2] : memref<100xf32>
304+
}
305+
return
306+
}
307+
283308
// The test checks for value forwarding from vector stores to vector loads.
284309
// The value loaded from %in can directly be stored to %out by eliminating
285310
// store and load from %tmp.

0 commit comments

Comments
 (0)