Skip to content

Commit 7ac1750

Browse files
committed
[mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header
This is so that we can re-use the same code in Flang.
1 parent e75e248 commit 7ac1750

File tree

6 files changed

+134
-70
lines changed

6 files changed

+134
-70
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines utilities for storing a stack of generic context.
10+
// The context can be arbitrary data, possibly including file-scoped types.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_SUPPORT_STACKFRAME_H
15+
#define MLIR_SUPPORT_STACKFRAME_H
16+
17+
#include "mlir/IR/Visitors.h"
18+
#include "mlir/Support/TypeID.h"
19+
#include <memory>
20+
21+
namespace mlir {
22+
23+
/// Common CRTP base class for StateStack frames.
24+
class StateStackFrame {
25+
public:
26+
virtual ~StateStackFrame() = default;
27+
TypeID getTypeID() const { return typeID; }
28+
29+
protected:
30+
explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
31+
32+
private:
33+
const TypeID typeID;
34+
virtual void anchor() {};
35+
};
36+
37+
/// Concrete CRTP base class for StateStack frames. This is used for keeping a
38+
/// stack of common state useful for recursive IR conversions. For example, when
39+
/// translating operations with regions, users of StateStack can store state on
40+
/// StateStack before entering the region and inspect it when converting
41+
/// operations nested within that region. Users are expected to derive this
42+
/// class and put any relevant information into fields of the derived class. The
43+
/// usual isa/dyn_cast functionality is available for instances of derived
44+
/// classes.
45+
template <typename Derived>
46+
class StateStackFrameBase : public StateStackFrame {
47+
public:
48+
explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
49+
};
50+
51+
class StateStack {
52+
public:
53+
/// Creates a stack frame of type `T` on StateStack. `T` must
54+
/// be derived from `StackFrameBase<T>` and constructible from the provided
55+
/// arguments. Doing this before entering the region of the op being
56+
/// translated makes the frame available when translating ops within that
57+
/// region.
58+
template <typename T, typename... Args>
59+
void stackPush(Args &&...args) {
60+
static_assert(std::is_base_of<StateStackFrame, T>::value,
61+
"can only push instances of StackFrame on StateStack");
62+
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
63+
}
64+
65+
/// Pops the last element from the StateStack.
66+
void stackPop() { stack.pop_back(); }
67+
68+
/// Calls `callback` for every StateStack frame of type `T`
69+
/// starting from the top of the stack.
70+
template <typename T>
71+
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
72+
static_assert(std::is_base_of<StateStackFrame, T>::value,
73+
"expected T derived from StackFrame");
74+
if (!callback)
75+
return WalkResult::skip();
76+
for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
77+
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
78+
WalkResult result = callback(*ptr);
79+
if (result.wasInterrupted())
80+
return result;
81+
}
82+
}
83+
return WalkResult::advance();
84+
}
85+
86+
private:
87+
SmallVector<std::unique_ptr<StateStackFrame>> stack;
88+
};
89+
90+
/// RAII object calling stackPush/stackPop on construction/destruction.
91+
/// HOST_CLASS could be a StateStack or some other class which forwards calls to
92+
/// one.
93+
template <typename T, typename HOST_CLASS>
94+
struct SaveStateStack {
95+
template <typename... Args>
96+
explicit SaveStateStack(HOST_CLASS &host, Args &&...args) : host(host) {
97+
host.template stackPush<T>(std::forward<Args>(args)...);
98+
}
99+
~SaveStateStack() { host.stackPop(); }
100+
101+
private:
102+
HOST_CLASS &host;
103+
};
104+
105+
} // namespace mlir
106+
107+
namespace llvm {
108+
template <typename T>
109+
struct isa_impl<T, ::mlir::StateStackFrame> {
110+
static inline bool doit(const ::mlir::StateStackFrame &frame) {
111+
return frame.getTypeID() == ::mlir::TypeID::get<T>();
112+
}
113+
};
114+
} // namespace llvm
115+
116+
#endif // MLIR_SUPPORT_STACKFRAME_H

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 6 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/Operation.h"
1919
#include "mlir/IR/SymbolTable.h"
2020
#include "mlir/IR/Value.h"
21+
#include "mlir/Support/StateStack.h"
2122
#include "mlir/Target/LLVMIR/Export.h"
2223
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
2324
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -271,80 +272,29 @@ class ModuleTranslation {
271272
/// it if it does not exist.
272273
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
273274

274-
/// Common CRTP base class for ModuleTranslation stack frames.
275-
class StackFrame {
276-
public:
277-
virtual ~StackFrame() = default;
278-
TypeID getTypeID() const { return typeID; }
279-
280-
protected:
281-
explicit StackFrame(TypeID typeID) : typeID(typeID) {}
282-
283-
private:
284-
const TypeID typeID;
285-
virtual void anchor();
286-
};
287-
288-
/// Concrete CRTP base class for ModuleTranslation stack frames. When
289-
/// translating operations with regions, users of ModuleTranslation can store
290-
/// state on ModuleTranslation stack before entering the region and inspect
291-
/// it when converting operations nested within that region. Users are
292-
/// expected to derive this class and put any relevant information into fields
293-
/// of the derived class. The usual isa/dyn_cast functionality is available
294-
/// for instances of derived classes.
295-
template <typename Derived>
296-
class StackFrameBase : public StackFrame {
297-
public:
298-
explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
299-
};
300-
301275
/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
302276
/// be derived from `StackFrameBase<T>` and constructible from the provided
303277
/// arguments. Doing this before entering the region of the op being
304278
/// translated makes the frame available when translating ops within that
305279
/// region.
306280
template <typename T, typename... Args>
307281
void stackPush(Args &&...args) {
308-
static_assert(
309-
std::is_base_of<StackFrame, T>::value,
310-
"can only push instances of StackFrame on ModuleTranslation stack");
311-
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
282+
stack.stackPush<T>(std::forward<Args>(args)...);
312283
}
313284

314285
/// Pops the last element from the ModuleTranslation stack.
315-
void stackPop() { stack.pop_back(); }
286+
void stackPop() { stack.stackPop(); }
316287

317288
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
318289
/// starting from the top of the stack.
319290
template <typename T>
320291
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
321-
static_assert(std::is_base_of<StackFrame, T>::value,
322-
"expected T derived from StackFrame");
323-
if (!callback)
324-
return WalkResult::skip();
325-
for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
326-
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
327-
WalkResult result = callback(*ptr);
328-
if (result.wasInterrupted())
329-
return result;
330-
}
331-
}
332-
return WalkResult::advance();
292+
return stack.stackWalk(callback);
333293
}
334294

335295
/// RAII object calling stackPush/stackPop on construction/destruction.
336296
template <typename T>
337-
struct SaveStack {
338-
template <typename... Args>
339-
explicit SaveStack(ModuleTranslation &m, Args &&...args)
340-
: moduleTranslation(m) {
341-
moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
342-
}
343-
~SaveStack() { moduleTranslation.stackPop(); }
344-
345-
private:
346-
ModuleTranslation &moduleTranslation;
347-
};
297+
using SaveStack = SaveStateStack<T, ModuleTranslation>;
348298

349299
SymbolTableCollection &symbolTable() { return symbolTableCollection; }
350300

@@ -468,7 +418,7 @@ class ModuleTranslation {
468418

469419
/// Stack of user-specified state elements, useful when translating operations
470420
/// with regions.
471-
SmallVector<std::unique_ptr<StackFrame>> stack;
421+
StateStack stack;
472422

473423
/// A cache for the symbol tables constructed during symbols lookup.
474424
SymbolTableCollection symbolTableCollection;
@@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall(
510460
} // namespace LLVM
511461
} // namespace mlir
512462

513-
namespace llvm {
514-
template <typename T>
515-
struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
516-
static inline bool
517-
doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
518-
return frame.getTypeID() == ::mlir::TypeID::get<T>();
519-
}
520-
};
521-
} // namespace llvm
522-
523463
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H

mlir/lib/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport
1111
FileUtilities.cpp
1212
InterfaceSupport.cpp
1313
RawOstreamExtras.cpp
14+
StateStack.cpp
1415
StorageUniquer.cpp
1516
Timing.cpp
1617
ToolUtilities.cpp

mlir/lib/Support/StateStack.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===- StateStack.cpp - Utility for storing a stack of state --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Support/StateStack.h"

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
7171
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
7272
/// insertion points for allocas.
7373
class OpenMPAllocaStackFrame
74-
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
74+
: public StateStackFrameBase<OpenMPAllocaStackFrame> {
7575
public:
7676
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
7777

@@ -84,7 +84,7 @@ class OpenMPAllocaStackFrame
8484
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
8585
/// operation.
8686
class OpenMPLoopInfoStackFrame
87-
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
87+
: public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
8888
public:
8989
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
9090
llvm::CanonicalLoopInfo *loopInfo = nullptr;

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
22252225
return llvmModule->getOrInsertNamedMetadata(name);
22262226
}
22272227

2228-
void ModuleTranslation::StackFrame::anchor() {}
2229-
22302228
static std::unique_ptr<llvm::Module>
22312229
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
22322230
StringRef name) {

0 commit comments

Comments
 (0)