Skip to content

Commit cc75671

Browse files
authored
[mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header (#144897)
This is so that we can re-use the same code in Flang.
1 parent f6973ba commit cc75671

File tree

6 files changed

+141
-70
lines changed

6 files changed

+141
-70
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines utilities for storing a stack of generic context.
10+
// The context can be arbitrary data, possibly including file-scoped types. Data
11+
// must be derived from StateStackFrameBase and implement MLIR TypeID.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_SUPPORT_STACKFRAME_H
16+
#define MLIR_SUPPORT_STACKFRAME_H
17+
18+
#include "mlir/IR/Visitors.h"
19+
#include "mlir/Support/TypeID.h"
20+
#include <memory>
21+
22+
namespace mlir {
23+
24+
/// Common CRTP base class for StateStack frames.
25+
class StateStackFrame {
26+
public:
27+
virtual ~StateStackFrame() = default;
28+
TypeID getTypeID() const { return typeID; }
29+
30+
protected:
31+
explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
32+
33+
private:
34+
const TypeID typeID;
35+
virtual void anchor();
36+
};
37+
38+
/// Concrete CRTP base class for StateStack frames. This is used for keeping a
39+
/// stack of common state useful for recursive IR conversions. For example, when
40+
/// translating operations with regions, users of StateStack can store state on
41+
/// StateStack before entering the region and inspect it when converting
42+
/// operations nested within that region. Users are expected to derive this
43+
/// class and put any relevant information into fields of the derived class. The
44+
/// usual isa/dyn_cast functionality is available for instances of derived
45+
/// classes.
46+
template <typename Derived>
47+
class StateStackFrameBase : public StateStackFrame {
48+
public:
49+
explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
50+
};
51+
52+
class StateStack {
53+
public:
54+
/// Creates a stack frame of type `T` on StateStack. `T` must
55+
/// be derived from `StackFrameBase<T>` and constructible from the provided
56+
/// arguments. Doing this before entering the region of the op being
57+
/// translated makes the frame available when translating ops within that
58+
/// region.
59+
template <typename T, typename... Args>
60+
void stackPush(Args &&...args) {
61+
static_assert(std::is_base_of<StateStackFrame, T>::value,
62+
"can only push instances of StackFrame on StateStack");
63+
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
64+
}
65+
66+
/// Pops the last element from the StateStack.
67+
void stackPop() { stack.pop_back(); }
68+
69+
/// Calls `callback` for every StateStack frame of type `T`
70+
/// starting from the top of the stack.
71+
template <typename T>
72+
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
73+
static_assert(std::is_base_of<StateStackFrame, T>::value,
74+
"expected T derived from StackFrame");
75+
if (!callback)
76+
return WalkResult::skip();
77+
for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
78+
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
79+
WalkResult result = callback(*ptr);
80+
if (result.wasInterrupted())
81+
return result;
82+
}
83+
}
84+
return WalkResult::advance();
85+
}
86+
87+
private:
88+
SmallVector<std::unique_ptr<StateStackFrame>> stack;
89+
};
90+
91+
/// RAII object calling stackPush/stackPop on construction/destruction.
92+
/// HostClass could be a StateStack or some other class which forwards calls to
93+
/// one.
94+
template <typename T, typename HostClass = StateStack>
95+
struct SaveStateStack {
96+
template <typename... Args>
97+
explicit SaveStateStack(HostClass &host, Args &&...args) : host(host) {
98+
host.template stackPush<T>(std::forward<Args>(args)...);
99+
}
100+
~SaveStateStack() { host.stackPop(); }
101+
102+
private:
103+
HostClass &host;
104+
};
105+
106+
} // namespace mlir
107+
108+
namespace llvm {
109+
template <typename T>
110+
struct isa_impl<T, ::mlir::StateStackFrame> {
111+
static inline bool doit(const ::mlir::StateStackFrame &frame) {
112+
return frame.getTypeID() == ::mlir::TypeID::get<T>();
113+
}
114+
};
115+
} // namespace llvm
116+
117+
#endif // MLIR_SUPPORT_STACKFRAME_H

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 6 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/Operation.h"
1919
#include "mlir/IR/SymbolTable.h"
2020
#include "mlir/IR/Value.h"
21+
#include "mlir/Support/StateStack.h"
2122
#include "mlir/Target/LLVMIR/Export.h"
2223
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
2324
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -271,80 +272,29 @@ class ModuleTranslation {
271272
/// it if it does not exist.
272273
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
273274

274-
/// Common CRTP base class for ModuleTranslation stack frames.
275-
class StackFrame {
276-
public:
277-
virtual ~StackFrame() = default;
278-
TypeID getTypeID() const { return typeID; }
279-
280-
protected:
281-
explicit StackFrame(TypeID typeID) : typeID(typeID) {}
282-
283-
private:
284-
const TypeID typeID;
285-
virtual void anchor();
286-
};
287-
288-
/// Concrete CRTP base class for ModuleTranslation stack frames. When
289-
/// translating operations with regions, users of ModuleTranslation can store
290-
/// state on ModuleTranslation stack before entering the region and inspect
291-
/// it when converting operations nested within that region. Users are
292-
/// expected to derive this class and put any relevant information into fields
293-
/// of the derived class. The usual isa/dyn_cast functionality is available
294-
/// for instances of derived classes.
295-
template <typename Derived>
296-
class StackFrameBase : public StackFrame {
297-
public:
298-
explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
299-
};
300-
301275
/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
302276
/// be derived from `StackFrameBase<T>` and constructible from the provided
303277
/// arguments. Doing this before entering the region of the op being
304278
/// translated makes the frame available when translating ops within that
305279
/// region.
306280
template <typename T, typename... Args>
307281
void stackPush(Args &&...args) {
308-
static_assert(
309-
std::is_base_of<StackFrame, T>::value,
310-
"can only push instances of StackFrame on ModuleTranslation stack");
311-
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
282+
stack.stackPush<T>(std::forward<Args>(args)...);
312283
}
313284

314285
/// Pops the last element from the ModuleTranslation stack.
315-
void stackPop() { stack.pop_back(); }
286+
void stackPop() { stack.stackPop(); }
316287

317288
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
318289
/// starting from the top of the stack.
319290
template <typename T>
320291
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
321-
static_assert(std::is_base_of<StackFrame, T>::value,
322-
"expected T derived from StackFrame");
323-
if (!callback)
324-
return WalkResult::skip();
325-
for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
326-
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
327-
WalkResult result = callback(*ptr);
328-
if (result.wasInterrupted())
329-
return result;
330-
}
331-
}
332-
return WalkResult::advance();
292+
return stack.stackWalk(callback);
333293
}
334294

335295
/// RAII object calling stackPush/stackPop on construction/destruction.
336296
template <typename T>
337-
struct SaveStack {
338-
template <typename... Args>
339-
explicit SaveStack(ModuleTranslation &m, Args &&...args)
340-
: moduleTranslation(m) {
341-
moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
342-
}
343-
~SaveStack() { moduleTranslation.stackPop(); }
344-
345-
private:
346-
ModuleTranslation &moduleTranslation;
347-
};
297+
using SaveStack = SaveStateStack<T, ModuleTranslation>;
348298

349299
SymbolTableCollection &symbolTable() { return symbolTableCollection; }
350300

@@ -468,7 +418,7 @@ class ModuleTranslation {
468418

469419
/// Stack of user-specified state elements, useful when translating operations
470420
/// with regions.
471-
SmallVector<std::unique_ptr<StackFrame>> stack;
421+
StateStack stack;
472422

473423
/// A cache for the symbol tables constructed during symbols lookup.
474424
SymbolTableCollection symbolTableCollection;
@@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall(
510460
} // namespace LLVM
511461
} // namespace mlir
512462

513-
namespace llvm {
514-
template <typename T>
515-
struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
516-
static inline bool
517-
doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
518-
return frame.getTypeID() == ::mlir::TypeID::get<T>();
519-
}
520-
};
521-
} // namespace llvm
522-
523463
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H

mlir/lib/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport
1111
FileUtilities.cpp
1212
InterfaceSupport.cpp
1313
RawOstreamExtras.cpp
14+
StateStack.cpp
1415
StorageUniquer.cpp
1516
Timing.cpp
1617
ToolUtilities.cpp

mlir/lib/Support/StateStack.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===- StateStack.cpp - Utility for storing a stack of state --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Support/StateStack.h"
10+
11+
namespace mlir {
12+
13+
void StateStackFrame::anchor() {}
14+
15+
} // namespace mlir

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
7171
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
7272
/// insertion points for allocas.
7373
class OpenMPAllocaStackFrame
74-
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
74+
: public StateStackFrameBase<OpenMPAllocaStackFrame> {
7575
public:
7676
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
7777

@@ -84,7 +84,7 @@ class OpenMPAllocaStackFrame
8484
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
8585
/// operation.
8686
class OpenMPLoopInfoStackFrame
87-
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
87+
: public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
8888
public:
8989
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
9090
llvm::CanonicalLoopInfo *loopInfo = nullptr;

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
22252225
return llvmModule->getOrInsertNamedMetadata(name);
22262226
}
22272227

2228-
void ModuleTranslation::StackFrame::anchor() {}
2229-
22302228
static std::unique_ptr<llvm::Module>
22312229
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
22322230
StringRef name) {

0 commit comments

Comments
 (0)