Skip to content

Commit 0ef4436

Browse files
committed
Push and pop stack frame within the same function
1 parent 254b7a5 commit 0ef4436

File tree

2 files changed

+42
-43
lines changed

2 files changed

+42
-43
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,12 @@ class ModuleTranslation {
290290
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
291291
/// starting from the top of the stack.
292292
template <typename T>
293-
WalkResult
294-
stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
293+
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
295294
static_assert(std::is_base_of<StackFrame, T>::value,
296295
"expected T derived from StackFrame");
297296
if (!callback)
298297
return WalkResult::skip();
299-
for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
298+
for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
300299
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
301300
WalkResult result = callback(*ptr);
302301
if (result.wasInterrupted())

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,7 @@ class OpenMPLoopInfoStackFrame
8282
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
8383
public:
8484
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
85-
86-
explicit OpenMPLoopInfoStackFrame(llvm::CanonicalLoopInfo *loopInfo)
87-
: loopInfo(loopInfo) {}
88-
llvm::CanonicalLoopInfo *loopInfo;
85+
llvm::CanonicalLoopInfo *loopInfo = nullptr;
8986
};
9087

9188
/// Custom error class to signal translation errors that don't need reporting,
@@ -348,13 +345,13 @@ static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
348345
/// normal operations in the builder.
349346
static llvm::OpenMPIRBuilder::InsertPointTy
350347
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
351-
const LLVM::ModuleTranslation &moduleTranslation) {
348+
LLVM::ModuleTranslation &moduleTranslation) {
352349
// If there is an alloca insertion point on stack, i.e. we are in a nested
353350
// operation and a specific point was provided by some surrounding operation,
354351
// use it.
355352
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
356353
WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
357-
[&](const OpenMPAllocaStackFrame &frame) {
354+
[&](OpenMPAllocaStackFrame &frame) {
358355
allocaInsertPoint = frame.allocaInsertPoint;
359356
return WalkResult::interrupt();
360357
});
@@ -386,13 +383,13 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
386383
}
387384

388385
/// Find the loop information structure for the loop nest being translated. It
389-
/// will not return a value unless called from the translation function for
386+
/// will return a `null` value unless called from the translation function for
390387
/// a loop wrapper operation after successfully translating its body.
391-
static std::optional<llvm::CanonicalLoopInfo *>
388+
static llvm::CanonicalLoopInfo *
392389
findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
393-
std::optional<llvm::CanonicalLoopInfo *> loopInfo;
390+
llvm::CanonicalLoopInfo *loopInfo = nullptr;
394391
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
395-
[&](const OpenMPLoopInfoStackFrame &frame) {
392+
[&](OpenMPLoopInfoStackFrame &frame) {
396393
loopInfo = frame.loopInfo;
397394
return WalkResult::interrupt();
398395
});
@@ -1987,7 +1984,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
19871984
return failure();
19881985

19891986
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
1990-
llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
1987+
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
19911988

19921989
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
19931990
ompBuilder->applyWorkshareLoop(
@@ -2270,16 +2267,16 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22702267
llvm::Value *alignment = nullptr;
22712268
llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
22722269
llvm::Type *ty = llvmVal->getType();
2273-
if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
2274-
alignment = builder.getInt64(intAttr.getInt());
2275-
assert(ty->isPointerTy() && "Invalid type for aligned variable");
2276-
assert(alignment && "Invalid alignment value");
2277-
auto curInsert = builder.saveIP();
2278-
builder.SetInsertPoint(sourceBlock);
2279-
llvmVal = builder.CreateLoad(ty, llvmVal);
2280-
builder.restoreIP(curInsert);
2281-
alignedVars[llvmVal] = alignment;
2282-
}
2270+
2271+
auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2272+
alignment = builder.getInt64(intAttr.getInt());
2273+
assert(ty->isPointerTy() && "Invalid type for aligned variable");
2274+
assert(alignment && "Invalid alignment value");
2275+
auto curInsert = builder.saveIP();
2276+
builder.SetInsertPoint(sourceBlock);
2277+
llvmVal = builder.CreateLoad(ty, llvmVal);
2278+
builder.restoreIP(curInsert);
2279+
alignedVars[llvmVal] = alignment;
22832280
}
22842281

22852282
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2289,7 +2286,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22892286
return failure();
22902287

22912288
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2292-
llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
2289+
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
22932290
ompBuilder->applySimd(loopInfo, alignedVars,
22942291
simdOp.getIfExpr()
22952292
? moduleTranslation.lookupValue(simdOp.getIfExpr())
@@ -2377,11 +2374,13 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
23772374
llvm::OpenMPIRBuilder::InsertPointTy afterIP =
23782375
loopInfos.front()->getAfterIP();
23792376

2380-
// Add a stack frame holding information about the resulting loop after
2381-
// applying transformations, to be further transformed by parent loop
2382-
// wrappers.
2383-
moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(
2384-
ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
2377+
// Update the stack frame created for this loop to point to the resulting loop
2378+
// after applying transformations.
2379+
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
2380+
[&](OpenMPLoopInfoStackFrame &frame) {
2381+
frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
2382+
return WalkResult::interrupt();
2383+
});
23852384

23862385
// Continue building IR after the loop. Note that the LoopInfo returned by
23872386
// `collapseLoops` points inside the outermost loop and is intended for
@@ -4576,6 +4575,19 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
45764575
LLVM::ModuleTranslation &moduleTranslation) {
45774576
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
45784577

4578+
// For each loop, introduce one stack frame to hold loop information. Ensure
4579+
// this is only done for the outermost loop wrapper to prevent introducing
4580+
// multiple stack frames for a single loop. Initially set to null, the loop
4581+
// information structure is initialized during translation of the nested
4582+
// omp.loop_nest operation, making it available to translation of all loop
4583+
// wrappers after their body has been successfully translated.
4584+
bool isOutermostLoopWrapper =
4585+
isa_and_present<omp::LoopWrapperInterface>(op) &&
4586+
!dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
4587+
4588+
if (isOutermostLoopWrapper)
4589+
moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
4590+
45794591
auto result =
45804592
llvm::TypeSwitch<Operation *, LogicalResult>(op)
45814593
.Case([&](omp::BarrierOp op) -> LogicalResult {
@@ -4700,19 +4712,7 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
47004712
<< "not yet implemented: " << inst->getName();
47014713
});
47024714

4703-
// When translating an omp.loop_nest, one stack frame was pushed to hold that
4704-
// loop's information. The code below ensures that this stack frame is removed
4705-
// when encountering the outermost loop wrapper associated to that loop. This
4706-
// approach allows all loop wrappers to have access to that loop's information
4707-
// (to e.g. apply transformations to it) after their associated omp.loop_nest
4708-
// operation has been translated.
4709-
bool isOutermostLoopWrapper =
4710-
isa_and_present<omp::LoopWrapperInterface>(op) &&
4711-
!dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
4712-
4713-
// We need to check that a loop info is present as well, in case translation
4714-
// of the loop failed before it was created.
4715-
if (isOutermostLoopWrapper && findCurrentLoopInfo(moduleTranslation))
4715+
if (isOutermostLoopWrapper)
47164716
moduleTranslation.stackPop();
47174717

47184718
return result;

0 commit comments

Comments
 (0)