Skip to content

[mlir] [dataflow] : Improve the time and space footprint of data flow. #135325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// may modify the program state; that is, every operation and block.
LogicalResult initialize(Operation *top) override;

/// Initialize lattice anchor equivalence class from the provided top-level
/// operation.
///
/// This function will union lattice anchor to same equivalent class if the
/// analysis can determine the lattice content of lattice anchor is
/// necessarily identical under the corrensponding lattice type.
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;

/// Visit a program point that modifies the state of the program. If the
/// program point is at the beginning of a block, then the state is propagated
/// from control-flow predecessors or callsites. If the operation before
Expand All @@ -96,8 +104,8 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// dependency. That is, every time the lattice after anchor is updated, the
/// dependent program point must be visited, and the newly triggered visit
/// might update the lattice on dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor);
virtual const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) = 0;

/// Set the dense lattice at control flow entry point and propagate an update
/// if it changed.
Expand All @@ -114,6 +122,11 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// operation transfer function.
virtual LogicalResult processOperation(Operation *op);

/// Visit an operation. If this analysis can confirm that lattice content
/// of lattice anchors around operation are necessarily identical, join
/// them into the same equivalent class.
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }

/// Propagate the dense lattice forward along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
/// values correspond to control flow branches originating at or targeting the
Expand Down Expand Up @@ -252,6 +265,15 @@ class DenseForwardDataFlowAnalysis
return getOrCreate<LatticeT>(anchor);
}

/// Get the dense lattice on the given lattice anchor and add dependent as its
/// dependency. That is, every time the lattice after anchor is updated, the
/// dependent program point must be visited, and the newly triggered visit
/// might update the lattice on dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) override {
return getOrCreateFor<LatticeT>(dependent, anchor);
}

/// Set the dense lattice at control flow entry point and propagate an update
/// if it changed.
virtual void setToEntryState(LatticeT *lattice) = 0;
Expand Down Expand Up @@ -310,6 +332,14 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// may modify the program state; that is, every operation and block.
LogicalResult initialize(Operation *top) override;

/// Initialize lattice anchor equivalence class from the provided top-level
/// operation.
///
/// This function will union lattice anchor to same equivalent class if the
/// analysis can determine the lattice content of lattice anchor is
/// necessarily identical under the corrensponding lattice type.
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;

/// Visit a program point that modifies the state of the program. The state is
/// propagated along control flow directions for branch-, region- and
/// call-based control flow using the respective interfaces. For other
Expand All @@ -336,8 +366,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// dependency. That is, every time the lattice after anchor is updated, the
/// dependent program point must be visited, and the newly triggered visit
/// might update the lattice before dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor);
virtual const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) = 0;

/// Set the dense lattice before at the control flow exit point and propagate
/// the update if it changed.
Expand All @@ -353,6 +383,11 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// transfer function.
virtual LogicalResult processOperation(Operation *op);

/// Visit an operation. If this analysis can confirm that lattice content
/// of lattice anchors around operation are necessarily identical, join
/// them into the same equivalent class.
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }

/// Propagate the dense lattice backwards along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
/// values correspond to control flow branches originating at or targeting the
Expand Down Expand Up @@ -502,6 +537,15 @@ class DenseBackwardDataFlowAnalysis
return getOrCreate<LatticeT>(anchor);
}

/// Get the dense lattice on the given lattice anchor and add dependent as its
/// dependency. That is, every time the lattice after anchor is updated, the
/// dependent program point must be visited, and the newly triggered visit
/// might update the lattice before dependent.
virtual const AbstractDenseLattice *
getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor) override {
return getOrCreateFor<LatticeT>(dependent, anchor);
}

/// Set the dense lattice at control flow exit point (after the terminator)
/// and propagate an update if it changed.
virtual void setToExitState(LatticeT *lattice) = 0;
Expand Down
124 changes: 115 additions & 9 deletions mlir/include/mlir/Analysis/DataFlowFramework.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Compiler.h"
Expand Down Expand Up @@ -265,6 +266,14 @@ struct LatticeAnchor
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;

} // namespace mlir

template <>
struct llvm::DenseMapInfo<mlir::LatticeAnchor>
: public llvm::DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};

namespace mlir {

//===----------------------------------------------------------------------===//
// DataFlowConfig
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -332,7 +341,9 @@ class DataFlowSolver {
/// does not exist.
template <typename StateT, typename AnchorT>
const StateT *lookupState(AnchorT anchor) const {
const auto &mapIt = analysisStates.find(LatticeAnchor(anchor));
LatticeAnchor latticeAnchor =
getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
const auto &mapIt = analysisStates.find(latticeAnchor);
if (mapIt == analysisStates.end())
return nullptr;
auto it = mapIt->second.find(TypeID::get<StateT>());
Expand All @@ -344,12 +355,34 @@ class DataFlowSolver {
/// Erase any analysis state associated with the given lattice anchor.
template <typename AnchorT>
void eraseState(AnchorT anchor) {
LatticeAnchor la(anchor);
analysisStates.erase(LatticeAnchor(anchor));
LatticeAnchor latticeAnchor(anchor);

// Update equivalentAnchorMap.
for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
if (!eqClass.contains(latticeAnchor)) {
continue;
}
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
eqClass.findLeader(latticeAnchor);

// Update analysis states with new leader if needed.
if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
analysisStates[*leaderIt][TypeId] =
std::move(analysisStates[latticeAnchor][TypeId]);
}

eqClass.erase(latticeAnchor);
}

// Update analysis states.
analysisStates.erase(latticeAnchor);
}

// Erase all analysis states
void eraseAllStates() { analysisStates.clear(); }
/// Erase all analysis states.
void eraseAllStates() {
analysisStates.clear();
equivalentAnchorMap.clear();
}

/// Get a uniqued lattice anchor instance. If one is not present, it is
/// created with the provided arguments.
Expand Down Expand Up @@ -399,6 +432,20 @@ class DataFlowSolver {
template <typename StateT, typename AnchorT>
StateT *getOrCreateState(AnchorT anchor);

/// Get leader lattice anchor in equivalence lattice anchor group, return
/// input lattice anchor if input not found in equivalece lattice anchor
/// group.
template <typename StateT>
LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const;

/// Union input anchors under the given state.
template <typename StateT, typename AnchorT>
void unionLatticeAnchors(AnchorT anchor, AnchorT other);

/// Return given lattice is equivalent on given state.
template <typename StateT>
bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const;

/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
/// This should only be used when DataFlowSolver is running.
Expand Down Expand Up @@ -429,10 +476,15 @@ class DataFlowSolver {

/// A type-erased map of lattice anchors to associated analysis states for
/// first-class lattice anchors.
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
DenseMapInfo<LatticeAnchor::ParentTy>>
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>>
analysisStates;

/// A map of Ananlysis state type to the equivalent lattice anchors.
/// Lattice anchors are considered equivalent under a certain analysis state
/// type if and only if, the analysis states pointed to by these lattice
/// anchors necessarily contain identical value.
DenseMap<TypeID, llvm::EquivalenceClasses<LatticeAnchor>> equivalentAnchorMap;

/// Allow the base child analysis class to access the internals of the solver.
friend class DataFlowAnalysis;
};
Expand Down Expand Up @@ -564,6 +616,14 @@ class DataFlowAnalysis {
/// will provide a value for then.
virtual LogicalResult visit(ProgramPoint *point) = 0;

/// Initialize lattice anchor equivalence class from the provided top-level
/// operation.
///
/// This function will union lattice anchor to same equivalent class if the
/// analysis can determine the lattice content of lattice anchor is
/// necessarily identical under the corrensponding lattice type.
virtual void initializeEquivalentLatticeAnchor(Operation *top) { return; }

protected:
/// Create a dependency between the given analysis state and lattice anchor
/// on this analysis.
Expand All @@ -584,6 +644,12 @@ class DataFlowAnalysis {
return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
}

/// Union input anchors under the given state.
template <typename StateT, typename AnchorT>
void unionLatticeAnchors(AnchorT anchor, AnchorT other) {
return solver.unionLatticeAnchors<StateT>(anchor, other);
}

/// Get the analysis state associated with the lattice anchor. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
Expand All @@ -598,7 +664,9 @@ class DataFlowAnalysis {
template <typename StateT, typename AnchorT>
const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) {
StateT *state = getOrCreate<StateT>(anchor);
addDependency(state, dependent);
if (!solver.isEquivalent<StateT>(LatticeAnchor(anchor),
LatticeAnchor(dependent)))
addDependency(state, dependent);
return state;
}

Expand Down Expand Up @@ -644,10 +712,29 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
return static_cast<AnalysisT *>(childAnalyses.back().get());
}

template <typename StateT>
LatticeAnchor
DataFlowSolver::getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const {
if (!equivalentAnchorMap.contains(TypeID::get<StateT>())) {
return latticeAnchor;
}
const llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
equivalentAnchorMap.at(TypeID::get<StateT>());
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
eqClass.findLeader(latticeAnchor);
if (leaderIt != eqClass.member_end()) {
return *leaderIt;
}
return latticeAnchor;
}

template <typename StateT, typename AnchorT>
StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
// Replace to leader anchor if found.
LatticeAnchor latticeAnchor(anchor);
latticeAnchor = getLeaderAnchorOrSelf<StateT>(latticeAnchor);
std::unique_ptr<AnalysisState> &state =
analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()];
analysisStates[latticeAnchor][TypeID::get<StateT>()];
if (!state) {
state = std::unique_ptr<StateT>(new StateT(anchor));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
Expand All @@ -657,6 +744,25 @@ StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
return static_cast<StateT *>(state.get());
}

template <typename StateT>
bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
if (!equivalentAnchorMap.contains(TypeID::get<StateT>())) {
return false;
}
const llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
equivalentAnchorMap.at(TypeID::get<StateT>());
if (!eqClass.contains(lhs) || !eqClass.contains(rhs))
return false;
return eqClass.isEquivalent(lhs, rhs);
}

template <typename StateT, typename AnchorT>
void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
equivalentAnchorMap[TypeID::get<StateT>()];
eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
}

inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
state.print(os);
return os;
Expand Down
34 changes: 18 additions & 16 deletions mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ using namespace mlir::dataflow;
// AbstractDenseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
Operation *top) {
top->walk([&](Operation *op) {
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
return;
buildOperationEquivalentLatticeAnchor(op);
});
}

LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
if (failed(processOperation(top)))
Expand Down Expand Up @@ -240,18 +249,19 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
}
}

const AbstractDenseLattice *
AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) {
AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
return state;
}

//===----------------------------------------------------------------------===//
// AbstractDenseBackwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
Operation *top) {
top->walk([&](Operation *op) {
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
return;
buildOperationEquivalentLatticeAnchor(op);
});
}

LogicalResult
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
Expand Down Expand Up @@ -455,11 +465,3 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
before);
}
}

const AbstractDenseLattice *
AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) {
AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
return state;
}
5 changes: 5 additions & 0 deletions mlir/lib/Analysis/DataFlowFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
isRunning = true;
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });

// Initialize equivalent lattice anchors.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
analysis.initializeEquivalentLatticeAnchor(top);
}

// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
DATAFLOW_DEBUG(llvm::dbgs()
Expand Down
Loading
Loading