Skip to content

Commit 32a4e3f

Browse files
ftynse and pengmai authored
[mlir] support non-interprocedural dataflow analyses (#75583)
The core implementation of the dataflow analysis framework is interprocedural by design. While this offers better analysis precision, it also comes with additional cost as it takes longer for the analysis to reach the fixpoint state. Add a configuration mechanism to the dataflow solver to control whether it operates interprocedurally or not to offer clients a choice. As a positive side effect, this change also adds hooks for explicitly processing external/opaque function calls in the dataflow analyses, e.g., based off of attributes present in the function declaration or call operation such as alias scopes and modref available in the LLVM dialect. This change should not affect existing analyses and the default solver configuration remains interprocedural. Co-authored-by: Jacob Peng <[email protected]>
1 parent 82a1bff commit 32a4e3f

12 files changed

+771
-171
lines changed

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ namespace dataflow {
2727
// CallControlFlowAction
2828
//===----------------------------------------------------------------------===//
2929

30-
/// Indicates whether the control enters or exits the callee.
31-
enum class CallControlFlowAction { EnterCallee, ExitCallee };
30+
/// Indicates whether the control enters, exits, or skips over the callee (in
31+
/// the case of external functions).
32+
enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };
3233

3334
//===----------------------------------------------------------------------===//
3435
// AbstractDenseLattice
@@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
131132

132133
/// Propagate the dense lattice forward along the call control flow edge,
133134
/// which can be either entering or exiting the callee. Default implementation
134-
/// just meets the states, meaning that operations implementing
135-
/// `CallOpInterface` don't have any effect on the lattice that isn't already
136-
/// expressed by the interface itself.
135+
/// for enter and exit callee actions just joins the states, meaning that
136+
/// operations implementing `CallOpInterface` don't have any effect on the
137+
/// lattice that isn't already expressed by the interface itself. Default
138+
/// implementation for the external callee action additionally sets the
139+
/// "after" lattice to the entry state.
137140
virtual void visitCallControlFlowTransfer(CallOpInterface call,
138141
CallControlFlowAction action,
139142
const AbstractDenseLattice &before,
140143
AbstractDenseLattice *after) {
141144
join(after, before);
145+
// Note that `setToEntryState` may be a "partial fixpoint" for some
146+
// lattices, e.g., lattices that are lists of maps of other lattices will
147+
// only set fixpoint for "known" lattices.
148+
if (action == CallControlFlowAction::ExternalCallee)
149+
setToEntryState(after);
142150
}
143151

144152
/// Visit a program point within a region branch operation with predecessors
@@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
155163

156164
/// Visit an operation for which the data flow is described by the
157165
/// `CallOpInterface`.
158-
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
166+
void visitCallOperation(CallOpInterface call,
167+
const AbstractDenseLattice &before,
168+
AbstractDenseLattice *after);
159169
};
160170

161171
//===----------------------------------------------------------------------===//
@@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
361371

362372
/// Propagate the dense lattice backwards along the call control flow edge,
363373
/// which can be either entering or exiting the callee. Default implementation
364-
/// just meets the states, meaning that operations implementing
365-
/// `CallOpInterface` don't have any effect on hte lattice that isn't already
366-
/// expressed by the interface itself.
374+
/// for enter and exit callee action just meets the states, meaning that
375+
/// operations implementing `CallOpInterface` don't have any effect on the
376+
/// lattice that isn't already expressed by the interface itself. Default
377+
/// implementation for external callee action additionally sets the result to
378+
/// the exit (fixpoint) state.
367379
virtual void visitCallControlFlowTransfer(CallOpInterface call,
368380
CallControlFlowAction action,
369381
const AbstractDenseLattice &after,
370382
AbstractDenseLattice *before) {
371383
meet(before, after);
384+
385+
// Note that `setToExitState` may be a "partial fixpoint" for some lattices,
386+
// e.g., lattices that are lists of maps of other lattices will only
387+
// set fixpoint for "known" lattices.
388+
if (action == CallControlFlowAction::ExternalCallee)
389+
setToExitState(before);
372390
}
373391

374392
private:
@@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
394412
/// otherwise,
395413
/// - meet that state with the state before the call-like op, or use the
396414
/// custom logic if overridden by concrete analyses.
397-
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
415+
void visitCallOperation(CallOpInterface call,
416+
const AbstractDenseLattice &after,
417+
AbstractDenseLattice *before);
398418

399419
/// Symbol table for call-level control flow.
400420
SymbolTableCollection &symbolTable;

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "mlir/Analysis/DataFlowFramework.h"
1919
#include "mlir/IR/SymbolTable.h"
20+
#include "mlir/Interfaces/CallInterfaces.h"
2021
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2122
#include "llvm/ADT/SmallPtrSet.h"
2223

@@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
199200
ArrayRef<const AbstractSparseLattice *> operandLattices,
200201
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
201202

203+
/// The transfer function for calls to external functions.
204+
virtual void visitExternalCallImpl(
205+
CallOpInterface call,
206+
ArrayRef<const AbstractSparseLattice *> argumentLattices,
207+
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
208+
202209
/// Given an operation with region control-flow, the lattices of the operands,
203210
/// and a region successor, compute the lattice values for block arguments
204211
/// that are not accounted for by the branching control flow (ex. the bounds
@@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis
271278
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
272279
ArrayRef<StateT *> results) = 0;
273280

281+
/// Visit a call operation to an externally defined function given the
282+
/// lattices of its arguments.
283+
virtual void visitExternalCall(CallOpInterface call,
284+
ArrayRef<const StateT *> argumentLattices,
285+
ArrayRef<StateT *> resultLattices) {
286+
setAllToEntryStates(resultLattices);
287+
}
288+
274289
/// Given an operation with possible region control-flow, the lattices of the
275290
/// operands, and a region successor, compute the lattice values for block
276291
/// arguments that are not accounted for by the branching control flow (ex.
@@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis
321336
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
322337
resultLattices.size()});
323338
}
339+
void visitExternalCallImpl(
340+
CallOpInterface call,
341+
ArrayRef<const AbstractSparseLattice *> argumentLattices,
342+
ArrayRef<AbstractSparseLattice *> resultLattices) override {
343+
visitExternalCall(
344+
call,
345+
{reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
346+
argumentLattices.size()},
347+
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
348+
resultLattices.size()});
349+
}
324350
void visitNonControlFlowArgumentsImpl(
325351
Operation *op, const RegionSuccessor &successor,
326352
ArrayRef<AbstractSparseLattice *> argLattices,
@@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
363389
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
364390
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
365391

392+
/// The transfer function for calls to external functions.
393+
virtual void visitExternalCallImpl(
394+
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
395+
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
396+
366397
// Visit operands on branch instructions that are not forwarded.
367398
virtual void visitBranchOperand(OpOperand &operand) = 0;
368399

@@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis
444475
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
445476
ArrayRef<const StateT *> results) = 0;
446477

478+
/// Visit a call to an external function. This function is expected to set
479+
/// lattice values of the call operands. By default, calls `visitCallOperand`
480+
/// for all operands.
481+
virtual void visitExternalCall(CallOpInterface call,
482+
ArrayRef<StateT *> argumentLattices,
483+
ArrayRef<const StateT *> resultLattices) {
484+
(void)argumentLattices;
485+
(void)resultLattices;
486+
for (OpOperand &operand : call->getOpOperands()) {
487+
visitCallOperand(operand);
488+
}
489+
};
490+
447491
protected:
448492
/// Get the lattice element for a value.
449493
StateT *getLatticeElement(Value value) override {
@@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis
474518
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
475519
resultLattices.size()});
476520
}
521+
522+
void visitExternalCallImpl(
523+
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
524+
ArrayRef<const AbstractSparseLattice *> resultLattices) override {
525+
visitExternalCall(
526+
call,
527+
{reinterpret_cast<StateT *const *>(operandLattices.begin()),
528+
operandLattices.size()},
529+
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
530+
resultLattices.size()});
531+
}
477532
};
478533

479534
} // end namespace dataflow

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,32 @@ struct ProgramPoint
175175
/// Forward declaration of the data-flow analysis class.
176176
class DataFlowAnalysis;
177177

178+
//===----------------------------------------------------------------------===//
179+
// DataFlowConfig
180+
//===----------------------------------------------------------------------===//
181+
182+
/// Configuration class for data flow solver and child analyses. Follows the
183+
/// fluent API pattern.
184+
class DataFlowConfig {
185+
public:
186+
DataFlowConfig() = default;
187+
188+
/// Set whether the solver should operate interprocedurally, i.e. enter the
189+
/// callee body when available. Interprocedural analyses may be more precise,
190+
/// but also more expensive as more states need to be computed and the
191+
/// fixpoint convergence takes longer.
192+
DataFlowConfig &setInterprocedural(bool enable) {
193+
interprocedural = enable;
194+
return *this;
195+
}
196+
197+
/// Return `true` if the solver operates interprocedurally, `false` otherwise.
198+
bool isInterprocedural() const { return interprocedural; }
199+
200+
private:
201+
bool interprocedural = true;
202+
};
203+
178204
//===----------------------------------------------------------------------===//
179205
// DataFlowSolver
180206
//===----------------------------------------------------------------------===//
@@ -195,6 +221,9 @@ class DataFlowAnalysis;
195221
/// TODO: Optimize the internal implementation of the solver.
196222
class DataFlowSolver {
197223
public:
224+
explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
225+
: config(config) {}
226+
198227
/// Load an analysis into the solver. Return the analysis instance.
199228
template <typename AnalysisT, typename... Args>
200229
AnalysisT *load(Args &&...args);
@@ -236,7 +265,13 @@ class DataFlowSolver {
236265
/// dependent work items to the back of the queue.
237266
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
238267

268+
/// Get the configuration of the solver.
269+
const DataFlowConfig &getConfig() const { return config; }
270+
239271
private:
272+
/// Configuration of the dataflow solver.
273+
DataFlowConfig config;
274+
240275
/// The solver's work queue. Work items can be inserted to the front of the
241276
/// queue to be processed greedily, speeding up computations that otherwise
242277
/// quickly degenerate to quadratic due to propagation of state updates.
@@ -423,6 +458,9 @@ class DataFlowAnalysis {
423458
return state;
424459
}
425460

461+
/// Return the configuration of the solver used for this analysis.
462+
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
463+
426464
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
427465
/// When compiling with debugging, keep a name for the analysis.
428466
StringRef debugName;

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
5454
}
5555

5656
void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
57-
CallOpInterface call, AbstractDenseLattice *after) {
57+
CallOpInterface call, const AbstractDenseLattice &before,
58+
AbstractDenseLattice *after) {
59+
// Allow for customizing the behavior of calls to external symbols, including
60+
// when the analysis is explicitly marked as non-interprocedural.
61+
auto callable =
62+
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
63+
if (!getSolverConfig().isInterprocedural() ||
64+
(callable && !callable.getCallableRegion())) {
65+
return visitCallControlFlowTransfer(
66+
call, CallControlFlowAction::ExternalCallee, before, after);
67+
}
5868

5969
const auto *predecessors =
6070
getOrCreateFor<PredecessorState>(call.getOperation(), call);
61-
// If not all return sites are known, then conservatively assume we can't
62-
// reason about the data-flow.
71+
// Otherwise, if not all return sites are known, then conservatively assume we
72+
// can't reason about the data-flow.
6373
if (!predecessors->allPredecessorsKnown())
6474
return setToEntryState(after);
6575

@@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
108118
// If this is a call operation, then join its lattices across known return
109119
// sites.
110120
if (auto call = dyn_cast<CallOpInterface>(op))
111-
return visitCallOperation(call, after);
121+
return visitCallOperation(call, *before, after);
112122

113123
// Invoke the operation transfer function.
114124
visitOperationImpl(op, *before, after);
@@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
130140
if (callable && callable.getCallableRegion() == block->getParent()) {
131141
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
132142
// If not all callsites are known, conservatively mark all lattices as
133-
// having reached their pessimistic fixpoints.
134-
if (!callsites->allPredecessorsKnown())
143+
// having reached their pessimistic fixpoints. Do the same if
144+
// interprocedural analysis is not enabled.
145+
if (!callsites->allPredecessorsKnown() ||
146+
!getSolverConfig().isInterprocedural())
135147
return setToEntryState(after);
136148
for (Operation *callsite : callsites->getKnownPredecessors()) {
137149
// Get the dense lattice before the callsite.
@@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
267279
}
268280

269281
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
270-
CallOpInterface call, AbstractDenseLattice *before) {
282+
CallOpInterface call, const AbstractDenseLattice &after,
283+
AbstractDenseLattice *before) {
271284
// Find the callee.
272285
Operation *callee = call.resolveCallable(&symbolTable);
273286
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
274287
if (!callable)
275288
return setToExitState(before);
276289

277-
// No region means the callee is only declared in this module and we shouldn't
278-
// assume anything about it.
290+
// No region means the callee is only declared in this module.
279291
Region *region = callable.getCallableRegion();
280-
if (!region || region->empty())
281-
return setToExitState(before);
292+
if (!region || region->empty() || !getSolverConfig().isInterprocedural()) {
293+
return visitCallControlFlowTransfer(
294+
call, CallControlFlowAction::ExternalCallee, after, before);
295+
}
282296

283297
// Call-level control flow specifies the data flow here.
284298
//
@@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
324338
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
325339
before);
326340
if (auto call = dyn_cast<CallOpInterface>(op))
327-
return visitCallOperation(call, before);
341+
return visitCallOperation(call, *after, before);
328342

329343
// Invoke the operation transfer function.
330344
visitOperationImpl(op, *after, before);
@@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
359373
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
360374
// If not all call sites are known, conservatively mark all lattices as
361375
// having reached their pessimistic fix points.
362-
if (!callsites->allPredecessorsKnown())
376+
if (!callsites->allPredecessorsKnown() ||
377+
!getSolverConfig().isInterprocedural()) {
363378
return setToExitState(before);
379+
}
364380

365381
for (Operation *callsite : callsites->getKnownPredecessors()) {
366382
const AbstractDenseLattice *after;

0 commit comments

Comments
 (0)