Skip to content

Commit 66f5e03

Browse files
[mlir][dataflow] disallow outside use of propagateIfChanged for DataFlowSolver
1 parent 9423961 commit 66f5e03

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,17 @@ class DataFlowSolver {
394394
template <typename StateT, typename AnchorT>
395395
StateT *getOrCreateState(AnchorT anchor);
396396

397+
/// Get the configuration of the solver.
398+
const DataFlowConfig &getConfig() const { return config; }
399+
400+
protected:
397401
/// Propagate an update to an analysis state if it changed by pushing
398402
/// dependent work items to the back of the queue.
403+
/// This should only be used by DataFlowAnalysis instances.
404+
/// When used outside of DataFlowAnalysis, the solver won't process
405+
/// the work items
399406
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
400407

401-
/// Get the configuration of the solver.
402-
const DataFlowConfig &getConfig() const { return config; }
403-
404408
private:
405409
/// Configuration of the dataflow solver.
406410
DataFlowConfig config;

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,13 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
4545
std::optional<APInt> constant = getValue().getValue().getConstantValue();
4646
auto value = cast<Value>(anchor);
4747
auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48-
if (!constant)
49-
return solver->propagateIfChanged(
50-
cv, cv->join(ConstantValue::getUnknownConstant()));
48+
if (!constant) {
49+
auto changed = cv->join(ConstantValue::getUnknownConstant());
50+
if (changed == ChangeResult::Change) {
51+
cv->onUpdate(solver);
52+
}
53+
return;
54+
}
5155

5256
Dialect *dialect;
5357
if (auto *parent = value.getDefiningOp())
@@ -56,8 +60,11 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
5660
dialect = value.getParentBlock()->getParentOp()->getDialect();
5761

5862
Type type = getElementTypeOrSelf(value);
59-
solver->propagateIfChanged(
60-
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
63+
auto changed =
64+
cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect));
65+
if (changed == ChangeResult::Change) {
66+
cv->onUpdate(solver);
67+
}
6168
}
6269

6370
LogicalResult IntegerRangeAnalysis::visitOperation(

0 commit comments

Comments
 (0)