@@ -82,10 +82,7 @@ class OpenMPLoopInfoStackFrame
82
82
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
83
83
public:
84
84
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (OpenMPLoopInfoStackFrame)
85
-
86
- explicit OpenMPLoopInfoStackFrame (llvm::CanonicalLoopInfo *loopInfo)
87
- : loopInfo(loopInfo) {}
88
- llvm::CanonicalLoopInfo *loopInfo;
85
+ llvm::CanonicalLoopInfo *loopInfo = nullptr ;
89
86
};
90
87
91
88
// / Custom error class to signal translation errors that don't need reporting,
@@ -348,13 +345,13 @@ static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
348
345
// / normal operations in the builder.
349
346
static llvm::OpenMPIRBuilder::InsertPointTy
350
347
findAllocaInsertPoint (llvm::IRBuilderBase &builder,
351
- const LLVM::ModuleTranslation &moduleTranslation) {
348
+ LLVM::ModuleTranslation &moduleTranslation) {
352
349
// If there is an alloca insertion point on stack, i.e. we are in a nested
353
350
// operation and a specific point was provided by some surrounding operation,
354
351
// use it.
355
352
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
356
353
WalkResult walkResult = moduleTranslation.stackWalk <OpenMPAllocaStackFrame>(
357
- [&](const OpenMPAllocaStackFrame &frame) {
354
+ [&](OpenMPAllocaStackFrame &frame) {
358
355
allocaInsertPoint = frame.allocaInsertPoint ;
359
356
return WalkResult::interrupt ();
360
357
});
@@ -386,13 +383,13 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
386
383
}
387
384
388
385
// / Find the loop information structure for the loop nest being translated. It
389
- // / will not return a value unless called from the translation function for
386
+ // / will return a `null` value unless called from the translation function for
390
387
// / a loop wrapper operation after successfully translating its body.
391
- static std::optional< llvm::CanonicalLoopInfo *>
388
+ static llvm::CanonicalLoopInfo *
392
389
findCurrentLoopInfo (LLVM::ModuleTranslation &moduleTranslation) {
393
- std::optional< llvm::CanonicalLoopInfo *> loopInfo;
390
+ llvm::CanonicalLoopInfo *loopInfo = nullptr ;
394
391
moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
395
- [&](const OpenMPLoopInfoStackFrame &frame) {
392
+ [&](OpenMPLoopInfoStackFrame &frame) {
396
393
loopInfo = frame.loopInfo ;
397
394
return WalkResult::interrupt ();
398
395
});
@@ -1987,7 +1984,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
1987
1984
return failure ();
1988
1985
1989
1986
builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
1990
- llvm::CanonicalLoopInfo *loopInfo = * findCurrentLoopInfo (moduleTranslation);
1987
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
1991
1988
1992
1989
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
1993
1990
ompBuilder->applyWorkshareLoop (
@@ -2270,16 +2267,16 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2270
2267
llvm::Value *alignment = nullptr ;
2271
2268
llvm::Value *llvmVal = moduleTranslation.lookupValue (operands[i]);
2272
2269
llvm::Type *ty = llvmVal->getType ();
2273
- if ( auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
2274
- alignment = builder. getInt64 (intAttr. getInt () );
2275
- assert (ty-> isPointerTy () && " Invalid type for aligned variable " );
2276
- assert (alignment && " Invalid alignment value " );
2277
- auto curInsert = builder. saveIP ( );
2278
- builder.SetInsertPoint (sourceBlock );
2279
- llvmVal = builder.CreateLoad (ty, llvmVal );
2280
- builder.restoreIP (curInsert );
2281
- alignedVars[llvmVal] = alignment ;
2282
- }
2270
+
2271
+ auto intAttr = cast<IntegerAttr>((*alignmentValues)[i] );
2272
+ alignment = builder. getInt64 (intAttr. getInt () );
2273
+ assert (ty-> isPointerTy () && " Invalid type for aligned variable " );
2274
+ assert (alignment && " Invalid alignment value " );
2275
+ auto curInsert = builder.saveIP ( );
2276
+ builder.SetInsertPoint (sourceBlock );
2277
+ llvmVal = builder.CreateLoad (ty, llvmVal );
2278
+ builder. restoreIP (curInsert) ;
2279
+ alignedVars[llvmVal] = alignment;
2283
2280
}
2284
2281
2285
2282
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2289,7 +2286,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2289
2286
return failure ();
2290
2287
2291
2288
builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2292
- llvm::CanonicalLoopInfo *loopInfo = * findCurrentLoopInfo (moduleTranslation);
2289
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2293
2290
ompBuilder->applySimd (loopInfo, alignedVars,
2294
2291
simdOp.getIfExpr ()
2295
2292
? moduleTranslation.lookupValue (simdOp.getIfExpr ())
@@ -2377,11 +2374,13 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
2377
2374
llvm::OpenMPIRBuilder::InsertPointTy afterIP =
2378
2375
loopInfos.front ()->getAfterIP ();
2379
2376
2380
- // Add a stack frame holding information about the resulting loop after
2381
- // applying transformations, to be further transformed by parent loop
2382
- // wrappers.
2383
- moduleTranslation.stackPush <OpenMPLoopInfoStackFrame>(
2384
- ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {}));
2377
+ // Update the stack frame created for this loop to point to the resulting loop
2378
+ // after applying transformations.
2379
+ moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
2380
+ [&](OpenMPLoopInfoStackFrame &frame) {
2381
+ frame.loopInfo = ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {});
2382
+ return WalkResult::interrupt ();
2383
+ });
2385
2384
2386
2385
// Continue building IR after the loop. Note that the LoopInfo returned by
2387
2386
// `collapseLoops` points inside the outermost loop and is intended for
@@ -4576,6 +4575,19 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
4576
4575
LLVM::ModuleTranslation &moduleTranslation) {
4577
4576
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4578
4577
4578
+ // For each loop, introduce one stack frame to hold loop information. Ensure
4579
+ // this is only done for the outermost loop wrapper to prevent introducing
4580
+ // multiple stack frames for a single loop. Initially set to null, the loop
4581
+ // information structure is initialized during translation of the nested
4582
+ // omp.loop_nest operation, making it available to translation of all loop
4583
+ // wrappers after their body has been successfully translated.
4584
+ bool isOutermostLoopWrapper =
4585
+ isa_and_present<omp::LoopWrapperInterface>(op) &&
4586
+ !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp ());
4587
+
4588
+ if (isOutermostLoopWrapper)
4589
+ moduleTranslation.stackPush <OpenMPLoopInfoStackFrame>();
4590
+
4579
4591
auto result =
4580
4592
llvm::TypeSwitch<Operation *, LogicalResult>(op)
4581
4593
.Case ([&](omp::BarrierOp op) -> LogicalResult {
@@ -4700,19 +4712,7 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
4700
4712
<< " not yet implemented: " << inst->getName ();
4701
4713
});
4702
4714
4703
- // When translating an omp.loop_nest, one stack frame was pushed to hold that
4704
- // loop's information. The code below ensures that this stack frame is removed
4705
- // when encountering the outermost loop wrapper associated to that loop. This
4706
- // approach allows all loop wrappers have access to that loop's information
4707
- // (to e.g. apply transformations to it) after their associated omp.loop_nest
4708
- // operation has been translated.
4709
- bool isOutermostLoopWrapper =
4710
- isa_and_present<omp::LoopWrapperInterface>(op) &&
4711
- !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp ());
4712
-
4713
- // We need to check that a loop info is present as well, in case translation
4714
- // of the loop failed before it was created.
4715
- if (isOutermostLoopWrapper && findCurrentLoopInfo (moduleTranslation))
4715
+ if (isOutermostLoopWrapper)
4716
4716
moduleTranslation.stackPop ();
4717
4717
4718
4718
return result;
0 commit comments