Skip to content

Commit 485746f

Browse files
River707jpienaar
authored andcommitted
Implement the initial AnalysisManagement infrastructure, with the introduction of the FunctionAnalysisManager and ModuleAnalysisManager classes. These classes provide analysis computation, caching, and invalidation for a specific IR unit. The invalidation is currently limited to either all or none, i.e. you cannot yet preserve specific analyses.
An analysis can be any class, but it must provide the following: * A constructor for a given IR unit. struct MyAnalysis { // Compute this analysis with the provided module. MyAnalysis(Module *module); }; Analyses can be accessed from a Pass by calling either the 'getAnalysisResult<AnalysisT>' or 'getCachedAnalysisResult<AnalysisT>' methods. A FunctionPass may query for a cached analysis on the parent module with 'getCachedModuleAnalysisResult'. Similary, a ModulePass may query an analysis, it doesn't need to be cached, on a child function with 'getFunctionAnalysisResult'. By default, when running a pass all cached analyses are set to be invalidated. If no transformation was performed, a pass can use the method 'markAllAnalysesPreserved' to preserve all analysis results. As noted above, preserving specific analyses is not yet supported. PiperOrigin-RevId: 236505642
1 parent c1b02a1 commit 485746f

File tree

5 files changed

+360
-42
lines changed

5 files changed

+360
-42
lines changed
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
//===- AnalysisManager.h - Analysis Management Infrastructure ---*- C++ -*-===//
2+
//
3+
// Copyright 2019 The MLIR Authors.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
// =============================================================================
17+
18+
#ifndef MLIR_PASS_ANALYSISMANAGER_H
19+
#define MLIR_PASS_ANALYSISMANAGER_H
20+
21+
#include "mlir/IR/Module.h"
22+
#include "mlir/Support/LLVM.h"
23+
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/SmallPtrSet.h"
25+
26+
namespace mlir {
27+
/// A special type used by analyses to provide an address that identifies a
28+
/// particular analysis set or a concrete analysis type.
29+
struct AnalysisID {
30+
template <typename AnalysisT> static AnalysisID *getID() {
31+
static AnalysisID id;
32+
return &id;
33+
}
34+
};
35+
36+
//===----------------------------------------------------------------------===//
37+
// Analysis Preservation and Result Modeling
38+
//===----------------------------------------------------------------------===//
39+
40+
namespace detail {
41+
/// A utility class to represent the analyses that are known to be preserved.
42+
class PreservedAnalyses {
43+
public:
44+
/// Mark all analyses as preserved.
45+
void preserveAll() { preservedIDs.insert(&allAnalysesID); }
46+
47+
/// Returns if all analyses were marked preserved.
48+
bool isAll() const { return preservedIDs.count(&allAnalysesID); }
49+
50+
private:
51+
/// An identifier used to represent all potential analyses.
52+
constexpr static AnalysisID allAnalysesID = {};
53+
54+
/// The set of analyses that are known to be preserved.
55+
SmallPtrSet<const void *, 2> preservedIDs;
56+
};
57+
58+
/// The abstract polymorphic base class representing an analysis.
59+
struct AnalysisConcept {
60+
virtual ~AnalysisConcept() = default;
61+
};
62+
63+
/// A derived analysis model used to hold a specific analysis object.
64+
template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
65+
template <typename... Args>
66+
explicit AnalysisModel(Args &&... args)
67+
: analysis(std::forward<Args>(args)...) {}
68+
69+
AnalysisT analysis;
70+
};
71+
72+
/// This class represents a cache of analysis results for a single IR unit. All
73+
/// computation, caching, and invalidation of analyses takes place here.
74+
template <typename IRUnitT> class AnalysisResultMap {
75+
/// A mapping between an analysis id and an existing analysis instance.
76+
using ResultMap =
77+
DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
78+
79+
public:
80+
explicit AnalysisResultMap(IRUnitT *ir) : ir(ir) {}
81+
82+
/// Get an analysis for the current IR unit, computing it if necessary.
83+
template <typename AnalysisT> AnalysisT &getResult() {
84+
typename ResultMap::iterator it;
85+
bool wasInserted;
86+
std::tie(it, wasInserted) =
87+
results.try_emplace(AnalysisID::getID<AnalysisT>());
88+
89+
// If we don't have a cached result for this function, compute it directly
90+
// and add it to the cache.
91+
if (wasInserted)
92+
it->second = llvm::make_unique<AnalysisModel<AnalysisT>>(ir);
93+
return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
94+
}
95+
96+
/// Get a cached analysis instance if one exists, otherwise return null.
97+
template <typename AnalysisT>
98+
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedResult() const {
99+
auto res = results.find(AnalysisID::getID<AnalysisT>());
100+
if (res == results.end())
101+
return llvm::None;
102+
return {static_cast<AnalysisModel<AnalysisT> &>(*res->second).analysis};
103+
}
104+
105+
/// Returns the IR unit that this result map represents.
106+
IRUnitT *getIRUnit() { return ir; }
107+
const IRUnitT *getIRUnit() const { return ir; }
108+
109+
/// Clear any held analysis results.
110+
void clear() { results.clear(); }
111+
112+
/// Invalidate any cached analyses based upon the given set of preserved
113+
/// analyses.
114+
void invalidate(const detail::PreservedAnalyses &pa) {
115+
// If all analyses were preserved, then there is nothing to do here.
116+
if (pa.isAll())
117+
return;
118+
// TODO: Fine grain invalidation of analyses.
119+
clear();
120+
}
121+
122+
private:
123+
IRUnitT *ir;
124+
ResultMap results;
125+
};
126+
127+
} // namespace detail
128+
129+
//===----------------------------------------------------------------------===//
130+
// Analysis Management
131+
//===----------------------------------------------------------------------===//
132+
133+
/// An analysis manager for a specific function instance. This class can only be
134+
/// constructed from a ModuleAnalysisManager instance.
135+
class FunctionAnalysisManager {
136+
public:
137+
// Query for a cached analysis on the parent Module. The analysis may not
138+
// exist and if it does it may be stale.
139+
template <typename AnalysisT>
140+
llvm::Optional<std::reference_wrapper<AnalysisT>>
141+
getCachedModuleResult() const {
142+
return parentImpl->getCachedResult<AnalysisT>();
143+
}
144+
145+
// Query for the given analysis for the current function.
146+
template <typename AnalysisT> AnalysisT &getResult() {
147+
return impl->getResult<AnalysisT>();
148+
}
149+
150+
// Query for a cached entry of the given analysis on the current function.
151+
template <typename AnalysisT>
152+
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedResult() const {
153+
return impl->getCachedResult<AnalysisT>();
154+
}
155+
156+
/// Invalidate any non preserved analyses,
157+
void invalidate(const detail::PreservedAnalyses &pa) { impl->invalidate(pa); }
158+
159+
/// Clear any held analyses.
160+
void clear() { impl->clear(); }
161+
162+
private:
163+
FunctionAnalysisManager(const detail::AnalysisResultMap<Module> *parentImpl,
164+
detail::AnalysisResultMap<Function> *impl)
165+
: parentImpl(parentImpl), impl(impl) {}
166+
167+
/// A reference to the results map of the parent module within the owning
168+
/// analysis manager.
169+
const detail::AnalysisResultMap<Module> *parentImpl;
170+
171+
/// A reference to the results map within the owning analysis manager.
172+
detail::AnalysisResultMap<Function> *impl;
173+
174+
/// Allow access to the constructor.
175+
friend class ModuleAnalysisManager;
176+
};
177+
178+
/// An analysis manager for a specific module instance.
179+
class ModuleAnalysisManager {
180+
public:
181+
ModuleAnalysisManager(Module *module) : moduleAnalyses(module) {}
182+
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
183+
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
184+
185+
/// Query for the analysis of a function. The analysis is computed if it does
186+
/// not exist.
187+
template <typename AnalysisT>
188+
AnalysisT &getFunctionResult(Function *function) {
189+
return slice(function).getResult<AnalysisT>();
190+
}
191+
192+
/// Query for the analysis of a module. The analysis is computed if it does
193+
/// not exist.
194+
template <typename AnalysisT> AnalysisT &getResult() {
195+
return moduleAnalyses.getResult<AnalysisT>();
196+
}
197+
198+
/// Query for a cached analysis for the module, or return nullptr.
199+
template <typename AnalysisT>
200+
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedResult() const {
201+
return moduleAnalyses.getCachedResult<AnalysisT>();
202+
}
203+
204+
/// Create an analysis slice for the given child function.
205+
FunctionAnalysisManager slice(Function *function);
206+
207+
/// Invalidate any non preserved analyses.
208+
void invalidate(const detail::PreservedAnalyses &pa);
209+
210+
private:
211+
/// The cached analyses for functions within the current module.
212+
DenseMap<Function *, detail::AnalysisResultMap<Function>> functionAnalyses;
213+
214+
/// The analyses for the owning module.
215+
detail::AnalysisResultMap<Module> moduleAnalyses;
216+
};
217+
218+
} // end namespace mlir
219+
220+
#endif // MLIR_PASS_ANALYSISMANAGER_H

mlir/include/mlir/Pass/Pass.h

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
#ifndef MLIR_PASS_PASS_H
1919
#define MLIR_PASS_PASS_H
2020

21-
#include "mlir/IR/Module.h"
21+
#include "mlir/Pass/AnalysisManager.h"
2222
#include "mlir/Pass/PassRegistry.h"
2323
#include "llvm/ADT/PointerIntPair.h"
2424

2525
namespace mlir {
26-
class Function;
27-
class Module;
28-
2926
/// The abstract base pass class. This class contains information describing the
3027
/// derived pass object, e.g its kind and abstract PassInfo.
3128
class Pass {
@@ -67,18 +64,30 @@ class ModulePassExecutor;
6764

6865
/// The state for a single execution of a pass. This provides a unified
6966
/// interface for accessing and initializing necessary state for pass execution.
70-
template <typename IRUnitT> struct PassExecutionState {
71-
explicit PassExecutionState(IRUnitT *ir) : irAndPassFailed(ir, false) {}
67+
template <typename IRUnitT, typename AnalysisManagerT>
68+
struct PassExecutionState {
69+
PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager)
70+
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
7271

73-
/// The current IR unit being transformed.
72+
/// The current IR unit being transformed and a bool for if the pass signaled
73+
/// a failure.
7474
llvm::PointerIntPair<IRUnitT *, 1, bool> irAndPassFailed;
75+
76+
/// The analysis manager for the IR unit.
77+
AnalysisManagerT &analysisManager;
78+
79+
/// The set of preserved analyses for the current execution.
80+
detail::PreservedAnalyses preservedAnalyses;
7581
};
7682
} // namespace detail
7783

7884
/// Pass to transform a specific function within a module. Derived passes should
7985
/// not inherit from this class directly, and instead should use the CRTP
8086
/// FunctionPass class.
8187
class FunctionPassBase : public Pass {
88+
using PassStateT =
89+
detail::PassExecutionState<Function, FunctionAnalysisManager>;
90+
8291
public:
8392
static bool classof(const Pass *pass) {
8493
return pass->getKind() == Kind::FunctionPass;
@@ -96,19 +105,24 @@ class FunctionPassBase : public Pass {
96105
}
97106

98107
/// Returns the current pass state.
99-
detail::PassExecutionState<Function> &getPassState() {
108+
PassStateT &getPassState() {
100109
assert(passState && "pass state was never initialized");
101110
return *passState;
102111
}
103112

113+
/// Returns the current analysis manager.
114+
FunctionAnalysisManager &getAnalysisManager() {
115+
return getPassState().analysisManager;
116+
}
117+
104118
private:
105119
/// Forwarding function to execute this pass. Returns false if the pass
106120
/// execution failed, true otherwise.
107121
LLVM_NODISCARD
108-
bool run(Function *fn);
122+
bool run(Function *fn, FunctionAnalysisManager &fam);
109123

110124
/// The current execution state for the pass.
111-
llvm::Optional<detail::PassExecutionState<Function>> passState;
125+
llvm::Optional<PassStateT> passState;
112126

113127
/// Allow access to 'run'.
114128
friend detail::FunctionPassExecutor;
@@ -117,6 +131,8 @@ class FunctionPassBase : public Pass {
117131
/// Pass to transform a module. Derived passes should not inherit from this
118132
/// class directly, and instead should use the CRTP ModulePass class.
119133
class ModulePassBase : public Pass {
134+
using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
135+
120136
public:
121137
static bool classof(const Pass *pass) {
122138
return pass->getKind() == Kind::ModulePass;
@@ -132,19 +148,24 @@ class ModulePassBase : public Pass {
132148
Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); }
133149

134150
/// Returns the current pass state.
135-
detail::PassExecutionState<Module> &getPassState() {
151+
PassStateT &getPassState() {
136152
assert(passState && "pass state was never initialized");
137153
return *passState;
138154
}
139155

156+
/// Returns the current analysis manager.
157+
ModuleAnalysisManager &getAnalysisManager() {
158+
return getPassState().analysisManager;
159+
}
160+
140161
private:
141162
/// Forwarding function to execute this pass. Returns false if the pass
142163
/// execution failed, true otherwise.
143164
LLVM_NODISCARD
144-
bool run(Module *module);
165+
bool run(Module *module, ModuleAnalysisManager &mam);
145166

146167
/// The current execution state for the pass.
147-
llvm::Optional<detail::PassExecutionState<Module>> passState;
168+
llvm::Optional<PassStateT> passState;
148169

149170
/// Allow access to 'run'.
150171
friend detail::ModulePassExecutor;
@@ -162,13 +183,30 @@ class PassModel : public BasePassT {
162183
PassModel() : BasePassT(PassID::getID<PassT>()) {}
163184

164185
/// TODO(riverriddle) Provide additional utilities for cloning, getting the
165-
/// derived class name, etc..
186+
/// derived class name, etc.
166187

167188
/// Signal that some invariant was broken when running. The IR is allowed to
168189
/// be in an invalid state.
169190
void signalPassFailure() {
170191
this->getPassState().irAndPassFailed.setInt(true);
171192
}
193+
194+
/// Query the result of an analysis for the current ir unit.
195+
template <typename AnalysisT> AnalysisT &getAnalysisResult() {
196+
return this->getAnalysisManager().template getResult<AnalysisT>();
197+
}
198+
199+
/// Query the cached result of an analysis for the current ir unit if one
200+
/// exists.
201+
template <typename AnalysisT>
202+
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysisResult() {
203+
return this->getAnalysisManager().template getCachedResult<AnalysisT>();
204+
}
205+
206+
/// Mark all analyses as preserved.
207+
void markAllAnalysesPreserved() {
208+
this->getPassState().preservedAnalyses.preserveAll();
209+
}
172210
};
173211
} // end namespace detail
174212

@@ -183,14 +221,28 @@ class PassModel : public BasePassT {
183221
/// Derived function passes are expected to provide the following:
184222
/// - A 'void runOnFunction()' method.
185223
template <typename T>
186-
using FunctionPass = detail::PassModel<Function, T, FunctionPassBase>;
224+
struct FunctionPass : public detail::PassModel<Function, T, FunctionPassBase> {
225+
/// Returns the analysis result for the parent module if it exists.
226+
template <typename AnalysisT>
227+
llvm::Optional<std::reference_wrapper<AnalysisT>>
228+
getCachedModuleAnalysisResult() {
229+
return this->getAnalysisManager()
230+
.template getCachedModuleResult<AnalysisT>();
231+
}
232+
};
187233

188234
/// A model for providing module pass specific utilities.
189235
///
190236
/// Derived module passes are expected to provide the following:
191237
/// - A 'void runOnModule()' method.
192238
template <typename T>
193-
using ModulePass = detail::PassModel<Module, T, ModulePassBase>;
239+
struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
240+
/// Returns the analysis result for a child function.
241+
template <typename AnalysisT>
242+
AnalysisT &getFunctionAnalysisResult(Function *f) {
243+
return this->getAnalysisManager().template getFunctionResult<AnalysisT>(f);
244+
}
245+
};
194246
} // end namespace mlir
195247

196248
#endif // MLIR_PASS_PASS_H

0 commit comments

Comments
 (0)