Skip to content

Commit 2e7202b

Browse files
committed
[fir] Add data flow optimization pass
Add pass to perform store/load forwarding and potentially removing dead stores. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: kiranchandramohan, schweitz, mehdi_amini, awarzynski Differential Revision: https://reviews.llvm.org/D111288
1 parent af37d4b commit 2e7202b

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createAffineDemotionPass();
3131
std::unique_ptr<mlir::Pass> createFirToCfgPass();
3232
std::unique_ptr<mlir::Pass> createCharacterConversionPass();
3333
std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
34+
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
3435
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
3536

3637
/// Support for inlining on FIR.

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,17 @@ def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
120120
let constructor = "::fir::createExternalNameConversionPass()";
121121
}
122122

123+
def MemRefDataFlowOpt : FunctionPass<"fir-memref-dataflow-opt"> {
124+
let summary =
125+
"Perform store/load forwarding and potentially removing dead stores.";
126+
let description = [{
127+
This pass performs store to load forwarding to eliminate memory accesses and
128+
potentially the entire allocation if all the accesses are forwarded.
129+
}];
130+
let constructor = "::fir::createMemDataFlowOptPass()";
131+
let dependentDialects = [
132+
"fir::FIROpsDialect", "mlir::StandardOpsDialect"
133+
];
134+
}
135+
123136
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
55
CharacterConversion.cpp
66
Inliner.cpp
77
ExternalNameConversion.cpp
8+
MemRefDataFlowOpt.cpp
89
RewriteLoop.cpp
910

1011
DEPENDS
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "flang/Optimizer/Dialect/FIRDialect.h"
11+
#include "flang/Optimizer/Dialect/FIROps.h"
12+
#include "flang/Optimizer/Dialect/FIRType.h"
13+
#include "flang/Optimizer/Transforms/Passes.h"
14+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
15+
#include "mlir/IR/Dominance.h"
16+
#include "mlir/IR/Operation.h"
17+
#include "mlir/Transforms/Passes.h"
18+
#include "llvm/ADT/Optional.h"
19+
#include "llvm/ADT/STLExtras.h"
20+
#include "llvm/ADT/SmallVector.h"
21+
22+
#define DEBUG_TYPE "fir-memref-dataflow-opt"
23+
24+
namespace {
25+
26+
template <typename OpT>
27+
static std::vector<OpT> getSpecificUsers(mlir::Value v) {
28+
std::vector<OpT> ops;
29+
for (mlir::Operation *user : v.getUsers())
30+
if (auto op = dyn_cast<OpT>(user))
31+
ops.push_back(op);
32+
return ops;
33+
}
34+
35+
/// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
36+
/// and AffineWrite interface
37+
template <typename ReadOp, typename WriteOp>
38+
class LoadStoreForwarding {
39+
public:
40+
LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
41+
42+
// FIXME: This algorithm has a bug. It ignores escaping references between a
43+
// store and a load.
44+
llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
45+
std::vector<WriteOp> &&storeOps) {
46+
llvm::SmallVector<WriteOp> candidateSet;
47+
48+
for (auto storeOp : storeOps)
49+
if (domInfo->dominates(storeOp, loadOp))
50+
candidateSet.push_back(storeOp);
51+
52+
if (candidateSet.empty())
53+
return {};
54+
55+
llvm::Optional<WriteOp> nearestStore;
56+
for (auto candidate : candidateSet) {
57+
auto nearerThan = [&](WriteOp otherStore) {
58+
if (candidate == otherStore)
59+
return false;
60+
bool rv = domInfo->properlyDominates(candidate, otherStore);
61+
if (rv) {
62+
LLVM_DEBUG(llvm::dbgs()
63+
<< "candidate " << candidate << " is not the nearest to "
64+
<< loadOp << " because " << otherStore << " is closer\n");
65+
}
66+
return rv;
67+
};
68+
if (!llvm::any_of(candidateSet, nearerThan)) {
69+
nearestStore = mlir::cast<WriteOp>(candidate);
70+
break;
71+
}
72+
}
73+
if (!nearestStore) {
74+
LLVM_DEBUG(
75+
llvm::dbgs()
76+
<< "load " << loadOp << " has " << candidateSet.size()
77+
<< " store candidates, but this algorithm can't find a best.\n");
78+
}
79+
return nearestStore;
80+
}
81+
82+
llvm::Optional<ReadOp> findReadForWrite(WriteOp storeOp,
83+
std::vector<ReadOp> &&loadOps) {
84+
for (auto &loadOp : loadOps) {
85+
if (domInfo->dominates(storeOp, loadOp))
86+
return loadOp;
87+
}
88+
return {};
89+
}
90+
91+
private:
92+
mlir::DominanceInfo *domInfo;
93+
};
94+
95+
class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
96+
public:
97+
void runOnFunction() override {
98+
mlir::FuncOp f = getFunction();
99+
100+
auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
101+
LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
102+
f.walk([&](fir::LoadOp loadOp) {
103+
auto maybeStore = lsf.findStoreToForward(
104+
loadOp, getSpecificUsers<fir::StoreOp>(loadOp.memref()));
105+
if (maybeStore) {
106+
auto storeOp = maybeStore.getValue();
107+
LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
108+
<< " erasing load " << loadOp
109+
<< " with value from " << storeOp << '\n');
110+
loadOp.getResult().replaceAllUsesWith(storeOp.value());
111+
loadOp.erase();
112+
}
113+
});
114+
f.walk([&](fir::AllocaOp alloca) {
115+
for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
116+
if (!lsf.findReadForWrite(
117+
storeOp, getSpecificUsers<fir::LoadOp>(storeOp.memref()))) {
118+
LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
119+
<< " erasing store " << storeOp << '\n');
120+
storeOp.erase();
121+
}
122+
}
123+
});
124+
}
125+
};
126+
} // namespace
127+
128+
std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
129+
return std::make_unique<MemDataFlowOpt>();
130+
}

flang/test/Fir/memref-data-flow.fir

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// RUN: fir-opt --split-input-file --fir-memref-dataflow-opt %s | FileCheck %s
2+
3+
// Test that all load-store chains are removed
4+
5+
func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.ref<!fir.array<60xi32>>, %arg2: !fir.ref<!fir.array<60xi32>>) {
6+
%c1_i64 = arith.constant 1 : i64
7+
%c60 = arith.constant 60 : index
8+
%c0 = arith.constant 0 : index
9+
%c1 = arith.constant 1 : index
10+
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
11+
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
12+
br ^bb1(%c1, %c60 : index, index)
13+
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
14+
%4 = arith.cmpi sgt, %3, %c0 : index
15+
cond_br %4, ^bb2, ^bb3
16+
^bb2: // pred: ^bb1
17+
%5 = fir.convert %2 : (index) -> i32
18+
fir.store %5 to %0 : !fir.ref<i32>
19+
%6 = fir.load %0 : !fir.ref<i32>
20+
%7 = fir.convert %6 : (i32) -> i64
21+
%8 = arith.subi %7, %c1_i64 : i64
22+
%9 = fir.coordinate_of %arg0, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
23+
%10 = fir.load %9 : !fir.ref<i32>
24+
%11 = arith.addi %10, %10 : i32
25+
%12 = fir.coordinate_of %1, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
26+
fir.store %11 to %12 : !fir.ref<i32>
27+
%13 = arith.addi %2, %c1 : index
28+
%14 = arith.subi %3, %c1 : index
29+
br ^bb1(%13, %14 : index, index)
30+
^bb3: // pred: ^bb1
31+
%15 = fir.convert %2 : (index) -> i32
32+
fir.store %15 to %0 : !fir.ref<i32>
33+
br ^bb4(%c1, %c60 : index, index)
34+
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
35+
%18 = arith.cmpi sgt, %17, %c0 : index
36+
cond_br %18, ^bb5, ^bb6
37+
^bb5: // pred: ^bb4
38+
%19 = fir.convert %16 : (index) -> i32
39+
fir.store %19 to %0 : !fir.ref<i32>
40+
%20 = fir.load %0 : !fir.ref<i32>
41+
%21 = fir.convert %20 : (i32) -> i64
42+
%22 = arith.subi %21, %c1_i64 : i64
43+
%23 = fir.coordinate_of %1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
44+
%24 = fir.load %23 : !fir.ref<i32>
45+
%25 = fir.coordinate_of %arg1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
46+
%26 = fir.load %25 : !fir.ref<i32>
47+
%27 = arith.muli %24, %26 : i32
48+
%28 = fir.coordinate_of %arg2, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
49+
fir.store %27 to %28 : !fir.ref<i32>
50+
%29 = arith.addi %16, %c1 : index
51+
%30 = arith.subi %17, %c1 : index
52+
br ^bb4(%29, %30 : index, index)
53+
^bb6: // pred: ^bb4
54+
%31 = fir.convert %16 : (index) -> i32
55+
fir.store %31 to %0 : !fir.ref<i32>
56+
return
57+
}
58+
59+
// CHECK-LABEL: func @load_store_chain_removal
60+
// CHECK-LABEL: ^bb1
61+
// CHECK-LABEL: ^bb2:
62+
// Make sure the previous fir.store/fir.load pair have been elimated and we
63+
// preserve the last pair of fir.load/fir.store.
64+
// CHECK-COUNT-1: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
65+
// CHECK-COUNT-1: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
66+
// CHECK-LABEL: ^bb3:
67+
// Make sure the fir.store has been removed.
68+
// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
69+
// CHECK-LABEL: ^bb5:
70+
// CHECK: %{{.*}} = fir.convert %{{.*}} : (index) -> i32
71+
// Check that the fir.store/fir.load pair has been removed between the convert.
72+
// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
73+
// CHECK-NOT: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
74+
// CHECK: %{{.*}} = fir.convert %{{.*}} : (i32) -> i64
75+
// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
76+
// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
77+
// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
78+
// CHECK-LABEL: ^bb6:
79+
// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>

0 commit comments

Comments
 (0)