Skip to content

[mlir] support non-interprocedural dataflow analyses #75583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ namespace dataflow {
// CallControlFlowAction
//===----------------------------------------------------------------------===//

/// Indicates whether the control enters or exits the callee.
enum class CallControlFlowAction { EnterCallee, ExitCallee };
/// Indicates whether the control enters, exits, or skips over the callee (in
/// the case of external functions).
enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };

//===----------------------------------------------------------------------===//
// AbstractDenseLattice
Expand Down Expand Up @@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {

/// Propagate the dense lattice forward along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
/// just meets the states, meaning that operations implementing
/// `CallOpInterface` don't have any effect on the lattice that isn't already
/// expressed by the interface itself.
/// for enter and exit callee actions just meets the states, meaning that
/// operations implementing `CallOpInterface` don't have any effect on the
/// lattice that isn't already expressed by the interface itself. Default
/// implementation for the external callee action additionally sets the
/// "after" lattice to the entry state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const AbstractDenseLattice &before,
AbstractDenseLattice *after) {
join(after, before);
// Note that `setToEntryState` may be a "partial fixpoint" for some
// lattices, e.g., lattices that are lists of maps of other lattices will
// only set fixpoint for "known" lattices.
if (action == CallControlFlowAction::ExternalCallee)
setToEntryState(after);
}

/// Visit a program point within a region branch operation with predecessors
Expand All @@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {

/// Visit an operation for which the data flow is described by the
/// `CallOpInterface`.
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
void visitCallOperation(CallOpInterface call,
const AbstractDenseLattice &before,
AbstractDenseLattice *after);
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {

/// Propagate the dense lattice backwards along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
/// just meets the states, meaning that operations implementing
/// `CallOpInterface` don't have any effect on hte lattice that isn't already
/// expressed by the interface itself.
/// for enter and exit callee action just meets the states, meaning that
/// operations implementing `CallOpInterface` don't have any effect on the
/// lattice that isn't already expressed by the interface itself. Default
/// implementation for external callee action additional sets the result to
/// the exit (fixpoint) state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);

// Note that `setToExitState` may be a "partial fixpoint" for some lattices,
// e.g., lattices that are lists of maps of other lattices will only
// set fixpoint for "known" lattices.
if (action == CallControlFlowAction::ExternalCallee)
setToExitState(before);
}

private:
Expand All @@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// otherwise,
/// - meet that state with the state before the call-like op, or use the
/// custom logic if overridden by concrete analyses.
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
void visitCallOperation(CallOpInterface call,
const AbstractDenseLattice &after,
AbstractDenseLattice *before);

/// Symbol table for call-level control flow.
SymbolTableCollection &symbolTable;
Expand Down
55 changes: 55 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

Expand Down Expand Up @@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// The transfer function for calls to external functions.
virtual void visitExternalCallImpl(
CallOpInterface call,
ArrayRef<const AbstractSparseLattice *> argumentLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// Given an operation with region control-flow, the lattices of the operands,
/// and a region successor, compute the lattice values for block arguments
/// that are not accounted for by the branching control flow (ex. the bounds
Expand Down Expand Up @@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;

/// Visit a call operation to an externally defined function given the
/// lattices of its arguments.
virtual void visitExternalCall(CallOpInterface call,
ArrayRef<const StateT *> argumentLattices,
ArrayRef<StateT *> resultLattices) {
setAllToEntryStates(resultLattices);
}

/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
/// arguments that are not accounted for by the branching control flow (ex.
Expand Down Expand Up @@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
void visitExternalCallImpl(
CallOpInterface call,
ArrayRef<const AbstractSparseLattice *> argumentLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) override {
visitExternalCall(
call,
{reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
argumentLattices.size()},
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices,
Expand Down Expand Up @@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;

/// The transfer function for calls to external functions.
virtual void visitExternalCallImpl(
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;

// Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;

Expand Down Expand Up @@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;

/// Visit a call to an external function. This function is expected to set
/// lattice values of the call operands. By default, calls `visitCallOperand`
/// for all operands.
virtual void visitExternalCall(CallOpInterface call,
ArrayRef<StateT *> argumentLattices,
ArrayRef<const StateT *> resultLattices) {
(void)argumentLattices;
(void)resultLattices;
for (OpOperand &operand : call->getOpOperands()) {
visitCallOperand(operand);
}
};

protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
Expand Down Expand Up @@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}

void visitExternalCallImpl(
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) override {
visitExternalCall(
call,
{reinterpret_cast<StateT *const *>(operandLattices.begin()),
operandLattices.size()},
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
};

} // end namespace dataflow
Expand Down
38 changes: 38 additions & 0 deletions mlir/include/mlir/Analysis/DataFlowFramework.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,32 @@ struct ProgramPoint
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;

//===----------------------------------------------------------------------===//
// DataFlowConfig
//===----------------------------------------------------------------------===//

/// Configuration class for data flow solver and child analyses. Follows the
/// fluent API pattern.
class DataFlowConfig {
public:
DataFlowConfig() = default;

/// Set whether the solver should operate interpocedurally, i.e. enter the
/// callee body when available. Interprocedural analyses may be more precise,
/// but also more expensive as more states need to be computed and the
/// fixpoint convergence takes longer.
DataFlowConfig &setInterprocedural(bool enable) {
interprocedural = enable;
return *this;
}

/// Return `true` if the solver operates interprocedurally, `false` otherwise.
bool isInterprocedural() const { return interprocedural; }

private:
bool interprocedural = true;
};

//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
Expand All @@ -195,6 +221,9 @@ class DataFlowAnalysis;
/// TODO: Optimize the internal implementation of the solver.
class DataFlowSolver {
public:
explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
: config(config) {}

/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
Expand Down Expand Up @@ -236,7 +265,13 @@ class DataFlowSolver {
/// dependent work items to the back of the queue.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);

/// Get the configuration of the solver.
const DataFlowConfig &getConfig() const { return config; }

private:
/// Configuration of the dataflow solver.
DataFlowConfig config;

/// The solver's work queue. Work items can be inserted to the front of the
/// queue to be processed greedily, speeding up computations that otherwise
/// quickly degenerate to quadratic due to propagation of state updates.
Expand Down Expand Up @@ -423,6 +458,9 @@ class DataFlowAnalysis {
return state;
}

/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analyis.
StringRef debugName;
Expand Down
42 changes: 29 additions & 13 deletions mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
}

void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, AbstractDenseLattice *after) {
CallOpInterface call, const AbstractDenseLattice &before,
AbstractDenseLattice *after) {
// Allow for customizing the behavior of calls to external symbols, including
// when the analysis is explicitly marked as non-interprocedural.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to have something like if !callable || !callable.getCallableRegion() here. I ask because I'm not sure how this would work in the presense of indirect function calls, which perhaps have some information on an attribute, but also I worry may otherwise may not be visited correctly for the required conservative assumptions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For indirect calls, callable will be null so we won't enter this branch. The control flow will proceed and, in the generic case, will enter branch with "not all predecessors known" condition because we can't reason about the callee of indirect calls. That will set everything to fixpoint and return, so it is conservatively correct. I haven't double-checked, but there may be a rare but happier situation where the set of possible callees is actually known (constant propagation or some custom logic), at which point we will be able to proceed interprocedurally.

We could add a customization hook for this situation, but I believe it should be different from "external callee" as (1) the callee is not necessarily external and (2) the implementer of the hook could benefit from knowing we are in the "unknown callee" situation. I'd suggest doing this as a separate change to we see the effects it has on various tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that seems reasonable to me.

return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, before, after);
}

const auto *predecessors =
getOrCreateFor<PredecessorState>(call.getOperation(), call);
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
// Otherwise, if not all return sites are known, then conservatively assume we
// can't reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
return setToEntryState(after);

Expand Down Expand Up @@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If this is a call operation, then join its lattices across known return
// sites.
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, after);
return visitCallOperation(call, *before, after);

// Invoke the operation transfer function.
visitOperationImpl(op, *before, after);
Expand All @@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (callable && callable.getCallableRegion() == block->getParent()) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
// having reached their pessimistic fixpoints. Do the same if
// interprocedural analysis is not enabled.
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural())
return setToEntryState(after);
for (Operation *callsite : callsites->getKnownPredecessors()) {
// Get the dense lattice before the callsite.
Expand Down Expand Up @@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
}

void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, AbstractDenseLattice *before) {
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
// Find the callee.
Operation *callee = call.resolveCallable(&symbolTable);
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
if (!callable)
return setToExitState(before);

// No region means the callee is only declared in this module and we shouldn't
// assume anything about it.
// No region means the callee is only declared in this module.
Region *region = callable.getCallableRegion();
if (!region || region->empty())
return setToExitState(before);
if (!region || region->empty() || !getSolverConfig().isInterprocedural()) {
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
}

// Call-level control flow specifies the data flow here.
//
Expand Down Expand Up @@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
return visitCallOperation(call, *after, before);

// Invoke the operation transfer function.
visitOperationImpl(op, *after, before);
Expand Down Expand Up @@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all call sites are known, conservative mark all lattices as
// having reached their pessimistic fix points.
if (!callsites->allPredecessorsKnown())
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural()) {
return setToExitState(before);
}

for (Operation *callsite : callsites->getKnownPredecessors()) {
const AbstractDenseLattice *after;
Expand Down
Loading