Skip to content

[mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header #144897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions mlir/include/mlir/Support/StateStack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for storing a stack of generic context.
// The context can be arbitrary data, possibly including file-scoped types. Data
// must be derived from StateStackFrameBase and implement MLIR TypeID.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_SUPPORT_STACKFRAME_H
#define MLIR_SUPPORT_STACKFRAME_H

#include "mlir/IR/Visitors.h"
#include "mlir/Support/TypeID.h"
#include <memory>

namespace mlir {

/// Common CRTP base class for StateStack frames.
class StateStackFrame {
public:
virtual ~StateStackFrame() = default;
TypeID getTypeID() const { return typeID; }

protected:
explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}

private:
const TypeID typeID;
virtual void anchor();
};

/// Concrete CRTP base class for StateStack frames. This is used for keeping a
/// stack of common state useful for recursive IR conversions. For example, when
/// translating operations with regions, users of StateStack can store state on
/// StateStack before entering the region and inspect it when converting
/// operations nested within that region. Users are expected to derive this
/// class and put any relevant information into fields of the derived class. The
/// usual isa/dyn_cast functionality is available for instances of derived
/// classes.
template <typename Derived>
class StateStackFrameBase : public StateStackFrame {
public:
explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
};

class StateStack {
public:
/// Creates a stack frame of type `T` on StateStack. `T` must
/// be derived from `StackFrameBase<T>` and constructible from the provided
/// arguments. Doing this before entering the region of the op being
/// translated makes the frame available when translating ops within that
/// region.
template <typename T, typename... Args>
void stackPush(Args &&...args) {
static_assert(std::is_base_of<StateStackFrame, T>::value,
"can only push instances of StackFrame on StateStack");
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
}

/// Pops the last element from the StateStack.
void stackPop() { stack.pop_back(); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there value in being, separately, able to access the top of the stack instead of discarding it or examining the top value before/after popping it off the stack?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably yes. My aim here was to keep the interface the same instead of attempting to improving it at the same time as moving, especially while there are no users for a return value from stackPop.


/// Calls `callback` for every StateStack frame of type `T`
/// starting from the top of the stack.
template <typename T>
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
static_assert(std::is_base_of<StateStackFrame, T>::value,
"expected T derived from StackFrame");
if (!callback)
return WalkResult::skip();
for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
WalkResult result = callback(*ptr);
if (result.wasInterrupted())
return result;
}
}
return WalkResult::advance();
}

private:
SmallVector<std::unique_ptr<StateStackFrame>> stack;
};

/// RAII object calling stackPush/stackPop on construction/destruction.
/// HostClass could be a StateStack or some other class which forwards calls to
/// one.
template <typename T, typename HostClass = StateStack>
struct SaveStateStack {
template <typename... Args>
explicit SaveStateStack(HostClass &host, Args &&...args) : host(host) {
host.template stackPush<T>(std::forward<Args>(args)...);
}
~SaveStateStack() { host.stackPop(); }

private:
HostClass &host;
};

} // namespace mlir

namespace llvm {
template <typename T>
struct isa_impl<T, ::mlir::StateStackFrame> {
static inline bool doit(const ::mlir::StateStackFrame &frame) {
return frame.getTypeID() == ::mlir::TypeID::get<T>();
}
};
} // namespace llvm

#endif // MLIR_SUPPORT_STACKFRAME_H
72 changes: 6 additions & 66 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/StateStack.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
Expand Down Expand Up @@ -271,80 +272,29 @@ class ModuleTranslation {
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);

/// Common CRTP base class for ModuleTranslation stack frames.
class StackFrame {
public:
virtual ~StackFrame() = default;
TypeID getTypeID() const { return typeID; }

protected:
explicit StackFrame(TypeID typeID) : typeID(typeID) {}

private:
const TypeID typeID;
virtual void anchor();
};

/// Concrete CRTP base class for ModuleTranslation stack frames. When
/// translating operations with regions, users of ModuleTranslation can store
/// state on ModuleTranslation stack before entering the region and inspect
/// it when converting operations nested within that region. Users are
/// expected to derive this class and put any relevant information into fields
/// of the derived class. The usual isa/dyn_cast functionality is available
/// for instances of derived classes.
template <typename Derived>
class StackFrameBase : public StackFrame {
public:
explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
};

/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
/// be derived from `StackFrameBase<T>` and constructible from the provided
/// arguments. Doing this before entering the region of the op being
/// translated makes the frame available when translating ops within that
/// region.
template <typename T, typename... Args>
void stackPush(Args &&...args) {
static_assert(
std::is_base_of<StackFrame, T>::value,
"can only push instances of StackFrame on ModuleTranslation stack");
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
stack.stackPush<T>(std::forward<Args>(args)...);
}

/// Pops the last element from the ModuleTranslation stack.
void stackPop() { stack.pop_back(); }
void stackPop() { stack.stackPop(); }

/// Calls `callback` for every ModuleTranslation stack frame of type `T`
/// starting from the top of the stack.
template <typename T>
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
static_assert(std::is_base_of<StackFrame, T>::value,
"expected T derived from StackFrame");
if (!callback)
return WalkResult::skip();
for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
WalkResult result = callback(*ptr);
if (result.wasInterrupted())
return result;
}
}
return WalkResult::advance();
return stack.stackWalk(callback);
}

/// RAII object calling stackPush/stackPop on construction/destruction.
template <typename T>
struct SaveStack {
template <typename... Args>
explicit SaveStack(ModuleTranslation &m, Args &&...args)
: moduleTranslation(m) {
moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
}
~SaveStack() { moduleTranslation.stackPop(); }

private:
ModuleTranslation &moduleTranslation;
};
using SaveStack = SaveStateStack<T, ModuleTranslation>;

SymbolTableCollection &symbolTable() { return symbolTableCollection; }

Expand Down Expand Up @@ -468,7 +418,7 @@ class ModuleTranslation {

/// Stack of user-specified state elements, useful when translating operations
/// with regions.
SmallVector<std::unique_ptr<StackFrame>> stack;
StateStack stack;

/// A cache for the symbol tables constructed during symbols lookup.
SymbolTableCollection symbolTableCollection;
Expand Down Expand Up @@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall(
} // namespace LLVM
} // namespace mlir

namespace llvm {
template <typename T>
struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
static inline bool
doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
return frame.getTypeID() == ::mlir::TypeID::get<T>();
}
};
} // namespace llvm

#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
1 change: 1 addition & 0 deletions mlir/lib/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport
FileUtilities.cpp
InterfaceSupport.cpp
RawOstreamExtras.cpp
StateStack.cpp
StorageUniquer.cpp
Timing.cpp
ToolUtilities.cpp
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Support/StateStack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//===- StateStack.cpp - Utility for storing a stack of state --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/StateStack.h"

namespace mlir {

void StateStackFrame::anchor() {}

} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
: public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)

Expand All @@ -84,7 +84,7 @@ class OpenMPAllocaStackFrame
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
: public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
llvm::CanonicalLoopInfo *loopInfo = nullptr;
Expand Down
2 changes: 0 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
}

void ModuleTranslation::StackFrame::anchor() {}

static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
Expand Down
Loading