Skip to content

Commit 6a66673

Browse files
phisiartJeff Niu
authored andcommitted
[mlir][dataflow] Unify dependency management in AnalysisState.
In the MLIR dataflow analysis framework, when an `AnalysisState` is updated, it's dependents are enqueued to be visited. Currently, there are two ways dependents are managed: * `AnalysisState::dependents` stores a list of dependents. `DataFlowSolver::propagateIfChanged()` reads this list and enqueues them to the worklist. * `AnalysisState::onUpdate()` allows custom logic to enqueue more to the worklist. This is called by `DataFlowSolver::propagateIfChanged()`. This cleanup diff consolidates the two into `AnalysisState::onUpdate()`. This way, `DataFlowSolver` does not need to know the detail about `AnalysisState::dependents`, and the logic of dependency management is entirely handled by `AnalysisState`. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D154170
1 parent 39dd4eb commit 6a66673

File tree

4 files changed

+41
-35
lines changed

4 files changed

+41
-35
lines changed

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,6 @@ class DataFlowSolver {
235235
/// dependent work items to the back of the queue.
236236
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
237237

238-
/// Add a dependency to an analysis state on a child analysis and program
239-
/// point. If the state is updated, the child analysis must be invoked on the
240-
/// given program point again.
241-
void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
242-
ProgramPoint point);
243-
244238
private:
245239
/// The solver's work queue. Work items can be inserted to the front of the
246240
/// queue to be processed greedily, speeding up computations that otherwise
@@ -294,13 +288,30 @@ class AnalysisState {
294288
/// Print the contents of the analysis state.
295289
virtual void print(raw_ostream &os) const = 0;
296290

291+
/// Add a dependency to this analysis state on a program point and an
292+
/// analysis. If this state is updated, the analysis will be invoked on the
293+
/// given program point again (in onUpdate()).
294+
void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
295+
297296
protected:
298297
/// This function is called by the solver when the analysis state is updated
299-
/// to optionally enqueue more work items. For example, if a state tracks
300-
/// dependents through the IR (e.g. use-def chains), this function can be
301-
/// implemented to push those dependents on the worklist.
302-
virtual void onUpdate(DataFlowSolver *solver) const {}
298+
/// to enqueue more work items. For example, if a state tracks dependents
299+
/// through the IR (e.g. use-def chains), this function can be implemented to
300+
/// push those dependents on the worklist.
301+
virtual void onUpdate(DataFlowSolver *solver) const {
302+
for (const DataFlowSolver::WorkItem &item : dependents)
303+
solver->enqueue(item);
304+
}
305+
306+
/// The program point to which the state belongs.
307+
ProgramPoint point;
308+
309+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
310+
/// When compiling with debugging, keep a name for the analysis state.
311+
StringRef debugName;
312+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
303313

314+
private:
304315
/// The dependency relations originating from this analysis state. An entry
305316
/// `state -> (analysis, point)` is created when `analysis` queries `state`
306317
/// when updating `point`.
@@ -312,14 +323,6 @@ class AnalysisState {
312323
/// Store the dependents on the analysis state for efficiency.
313324
SetVector<DataFlowSolver::WorkItem> dependents;
314325

315-
/// The program point to which the state belongs.
316-
ProgramPoint point;
317-
318-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
319-
/// When compiling with debugging, keep a name for the analysis state.
320-
StringRef debugName;
321-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
322-
323326
/// Allow the framework to access the dependents.
324327
friend class DataFlowSolver;
325328
};

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1010
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11+
#include "mlir/Analysis/DataFlowFramework.h"
1112
#include "mlir/Interfaces/CallInterfaces.h"
1213
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1314
#include <optional>
@@ -31,6 +32,8 @@ void Executable::print(raw_ostream &os) const {
3132
}
3233

3334
void Executable::onUpdate(DataFlowSolver *solver) const {
35+
AnalysisState::onUpdate(solver);
36+
3437
if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
3538
// Re-invoke the analyses on the block itself.
3639
for (DataFlowAnalysis *analysis : subscribers)

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
1010
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11+
#include "mlir/Analysis/DataFlowFramework.h"
1112
#include "mlir/Interfaces/CallInterfaces.h"
1213

1314
using namespace mlir;
@@ -18,6 +19,8 @@ using namespace mlir::dataflow;
1819
//===----------------------------------------------------------------------===//
1920

2021
void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
22+
AnalysisState::onUpdate(solver);
23+
2124
// Push all users of the value to the queue.
2225
for (Operation *user : point.get<Value>().getUsers())
2326
for (DataFlowAnalysis *analysis : useDefSubscribers)

mlir/lib/Analysis/DataFlowFramework.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ GenericProgramPoint::~GenericProgramPoint() = default;
3030

3131
AnalysisState::~AnalysisState() = default;
3232

33+
void AnalysisState::addDependency(ProgramPoint dependent,
34+
DataFlowAnalysis *analysis) {
35+
auto inserted = dependents.insert({dependent, analysis});
36+
(void)inserted;
37+
DATAFLOW_DEBUG({
38+
if (inserted) {
39+
llvm::dbgs() << "Creating dependency between " << debugName << " of "
40+
<< point << "\nand " << debugName << " on " << dependent
41+
<< "\n";
42+
}
43+
});
44+
}
45+
3346
//===----------------------------------------------------------------------===//
3447
// ProgramPoint
3548
//===----------------------------------------------------------------------===//
@@ -97,26 +110,10 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
97110
DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
98111
<< " of " << state->point << "\n"
99112
<< "Value: " << *state << "\n");
100-
for (const WorkItem &item : state->dependents)
101-
enqueue(item);
102113
state->onUpdate(this);
103114
}
104115
}
105116

106-
void DataFlowSolver::addDependency(AnalysisState *state,
107-
DataFlowAnalysis *analysis,
108-
ProgramPoint point) {
109-
auto inserted = state->dependents.insert({point, analysis});
110-
(void)inserted;
111-
DATAFLOW_DEBUG({
112-
if (inserted) {
113-
llvm::dbgs() << "Creating dependency between " << state->debugName
114-
<< " of " << state->point << "\nand " << analysis->debugName
115-
<< " on " << point << "\n";
116-
}
117-
});
118-
}
119-
120117
//===----------------------------------------------------------------------===//
121118
// DataFlowAnalysis
122119
//===----------------------------------------------------------------------===//
@@ -126,7 +123,7 @@ DataFlowAnalysis::~DataFlowAnalysis() = default;
126123
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
127124

128125
void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
129-
solver.addDependency(state, this, point);
126+
state->addDependency(point, this);
130127
}
131128

132129
void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,

0 commit comments

Comments
 (0)