[MLIR][OpenMP] Normalize lowering of omp.loop_nest #127217

Merged
1 commit merged into main from users/skatrak/target-spmd-01-loop-refactor on Feb 24, 2025

Conversation

@skatrak (Member) commented Feb 14, 2025

This patch refactors the translation of omp.loop_nest operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell to each loop wrapper, causing code duplication. This patch centralizes that handling of the loop. One consequence of this was fixing an issue lowering non-inclusive omp.simd loops.

As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only do/for simd is supported by ignoring SIMD information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the llvm::CanonicalLoopInfo loop information structure in order to apply transformations to it. This is now created in the nested call to convertOmpLoopNest, so it needs to be passed up to all associated loop wrapper translation functions. This is done via the creation of an OpenMPLoopInfoStackFrame within convertHostOrTargetOperation, associated with the outermost loop wrapper. This structure is updated by convertOmpLoopNest, making the result available to all loop wrappers after their body has been translated.
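
To make that flow concrete, here is a condensed sketch of the scheme (illustration only, not the verbatim patch: error handling is omitted and convertWrapperOrOtherOp is a hypothetical stand-in for the existing operation dispatch; the full version is in the linked .diff):

static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  // Push a frame only for the outermost loop wrapper of an omp.loop_nest, so
  // that all wrappers of that nest share a single frame.
  Operation *parent = op->getParentOp();
  bool isOutermostLoopWrapper =
      isa<omp::LoopWrapperInterface>(op) &&
      !(parent && isa<omp::LoopWrapperInterface>(parent));
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(/*loopInfo=*/nullptr);

  // Dispatches to convertOmpWsloop, convertOmpSimd, convertOmpLoopNest, etc.
  LogicalResult result = convertWrapperOrOtherOp(op, builder, moduleTranslation);

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();
  return result;
}

// Producer side: convertOmpLoopNest builds and collapses the canonical loops,
// then records the result in the enclosing frame, conceptually:
//   llvm::CanonicalLoopInfo *collapsed =
//       ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
//   /* store `collapsed` into the OpenMPLoopInfoStackFrame */
//
// Consumer side: a wrapper such as convertOmpWsloop reads it back once its
// region (and therefore the nested omp.loop_nest) has been translated:
//   llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
//   ompBuilder->applyWorkshareLoop(ompLoc.DL, loopInfo, allocaIP, ...);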

@llvmbot (Member) commented Feb 14, 2025

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch refactors the translation of omp.loop_nest operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell to each loop wrapper, causing code duplication. This patch centralizes that handling of the loop. One consequence of this was fixing an issue lowering non-inclusive omp.simd loops.

As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only do/for simd is supported by ignoring SIMD information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the llvm::CanonicalLoopInfo loop information structure in order to apply transformations to it. This is now created in the nested call to convertOmpLoopNest, so it needs to be passed up to all associated loop wrapper translation functions. This is done via the creation of an OpenMPLoopInfoStackFrame within convertOmpLoopNest and its removal after its outermost associated loop wrapper has been translated.


Patch is 36.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127217.diff

3 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+303-292)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-simd-private.mlir (+6-3)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 51a3cbdbb5e7f..a5ff3eff6439f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -75,6 +75,19 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
+/// collapsed canonical loop information corresponding to an \c omp.loop_nest
+/// operation.
+class OpenMPLoopInfoStackFrame
+    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
+
+  explicit OpenMPLoopInfoStackFrame(llvm::CanonicalLoopInfo *loopInfo)
+      : loopInfo(loopInfo) {}
+  llvm::CanonicalLoopInfo *loopInfo;
+};
+
 /// Custom error class to signal translation errors that don't need reporting,
 /// since encountering them will have already triggered relevant error messages.
 ///
@@ -372,6 +385,20 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
       &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
 }
 
+/// Find the loop information structure for the loop nest being translated. It
+/// will not return a value unless called from the translation function for
+/// a loop wrapper operation after successfully translating its body.
+static std::optional<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  std::optional<llvm::CanonicalLoopInfo *> loopInfo;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](const OpenMPLoopInfoStackFrame &frame) {
+        loopInfo = frame.loopInfo;
+        return WalkResult::interrupt();
+      });
+  return loopInfo;
+}
+
 /// Converts the given region that appears within an OpenMP dialect operation to
 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
 /// region, and a branch from any block with an successor-less OpenMP terminator
@@ -381,6 +408,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
   llvm::BasicBlock *continuationBlock =
       splitBB(builder, true, "omp.region.cont");
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -397,30 +426,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
 
   // Terminators (namely YieldOp) may be forwarding values to the region that
   // need to be available in the continuation block. Collect the types of these
-  // operands in preparation of creating PHI nodes.
+  // operands in preparation of creating PHI nodes. This is skipped for loop
+  // wrapper operations, for which we know in advance they have no terminators.
   SmallVector<llvm::Type *> continuationBlockPHITypes;
-  bool operandsProcessed = false;
   unsigned numYields = 0;
-  for (Block &bb : region.getBlocks()) {
-    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
-      if (!operandsProcessed) {
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          continuationBlockPHITypes.push_back(
-              moduleTranslation.convertType(yield->getOperand(i).getType()));
-        }
-        operandsProcessed = true;
-      } else {
-        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
-               "mismatching number of values yielded from the region");
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          llvm::Type *operandType =
-              moduleTranslation.convertType(yield->getOperand(i).getType());
-          (void)operandType;
-          assert(continuationBlockPHITypes[i] == operandType &&
-                 "values of mismatching types yielded from the region");
+
+  if (!isLoopWrapper) {
+    bool operandsProcessed = false;
+    for (Block &bb : region.getBlocks()) {
+      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+        if (!operandsProcessed) {
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            continuationBlockPHITypes.push_back(
+                moduleTranslation.convertType(yield->getOperand(i).getType()));
+          }
+          operandsProcessed = true;
+        } else {
+          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+                 "mismatching number of values yielded from the region");
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            llvm::Type *operandType =
+                moduleTranslation.convertType(yield->getOperand(i).getType());
+            (void)operandType;
+            assert(continuationBlockPHITypes[i] == operandType &&
+                   "values of mismatching types yielded from the region");
+          }
         }
+        numYields++;
       }
-      numYields++;
     }
   }
 
@@ -458,6 +491,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Create a direct branch here for loop wrappers to prevent their lack of a
+    // terminator from causing a crash below.
+    if (isLoopWrapper) {
+      builder.CreateBr(continuationBlock);
+      continue;
+    }
+
     // Special handling for `omp.yield` and `omp.terminator` (we may have more
     // than one): they return the control to the parent OpenMP dialect operation
     // so replace them with the branch to the continuation block. We handle this
@@ -509,7 +549,7 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
 /// This must be called after block arguments of parent wrappers have already
 /// been mapped to LLVM IR values.
 static LogicalResult
-convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
+convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
                       LLVM::ModuleTranslation &moduleTranslation) {
   // Map block arguments directly to the LLVM value associated to the
   // corresponding operand. This is semantically equivalent to this wrapper not
@@ -531,34 +571,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
         return success();
       })
       .Default([&](Operation *op) {
-        return op->emitError() << "cannot ignore nested wrapper";
+        return op->emitError() << "cannot ignore wrapper";
       });
 }
 
-/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
-/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
-/// entry block arguments defined by these operations to outside values.
-///
-/// It must be called after block arguments of \c parentOp have already been
-/// mapped themselves.
-static LogicalResult
-convertIgnoredWrappers(omp::LoopNestOp loopOp,
-                       omp::LoopWrapperInterface parentOp,
-                       LLVM::ModuleTranslation &moduleTranslation) {
-  SmallVector<omp::LoopWrapperInterface> wrappers;
-  loopOp.gatherWrappers(wrappers);
-
-  // Process wrappers nested inside of `parentOp` from outermost to innermost.
-  for (auto it =
-           std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
-       it != wrappers.rend(); ++it) {
-    if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
-      return failure();
-  }
-
-  return success();
-}
-
 /// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1876,6 +1892,7 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
 static LogicalResult
 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto wsloopOp = cast<omp::WsloopOp>(opInst);
   if (failed(checkImplementationStatus(opInst)))
     return failure();
@@ -1956,90 +1973,25 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                                reductionVariableMap, isByRef, deferredStores)))
     return failure();
 
-  // TODO: Replace this with proper composite translation support.
-  // Currently, all nested wrappers are ignored, so 'do/for simd' will be
-  // treated the same as a standalone 'do/for'. This is allowed by the spec,
-  // since it's equivalent to always using a SIMD length of 1.
-  if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
-    return failure();
-
-  // Set up the source location value for OpenMP runtime.
-  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-
-  // Generator of the canonical loop body.
-  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
-  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
-  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
-                     llvm::Value *iv) -> llvm::Error {
-    // Make sure further conversions know about the induction variable.
-    moduleTranslation.mapValue(
-        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
-
-    // Capture the body insertion point for use in nested loops. BodyIP of the
-    // CanonicalLoopInfo always points to the beginning of the entry block of
-    // the body.
-    bodyInsertPoints.push_back(ip);
-
-    if (loopInfos.size() != loopOp.getNumLoops() - 1)
-      return llvm::Error::success();
-
-    // Convert the body of the loop.
-    builder.restoreIP(ip);
-    return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
-                               moduleTranslation)
-        .takeError();
-  };
-
-  // Delegate actual loop construction to the OpenMP IRBuilder.
-  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
-  // loop, i.e. it has a positive step, uses signed integer semantics.
-  // Reconsider this code when the nested loop operation clearly supports more
-  // cases.
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
-    llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
-    llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
-    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
-
-    // Make sure loop trip count are emitted in the preheader of the outermost
-    // loop at the latest so that they are all available for the new collapsed
-    // loop will be created below.
-    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
-    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
-    if (i != 0) {
-      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
-      computeIP = loopInfos.front()->getPreheaderIP();
-    }
-
-    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
-        ompBuilder->createCanonicalLoop(
-            loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
-    if (failed(handleError(loopResult, *loopOp)))
-      return failure();
-
-    loopInfos.push_back(*loopResult);
-  }
-
-  // Collapse loops. Store the insertion point because LoopInfos may get
-  // invalidated.
-  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
-  llvm::CanonicalLoopInfo *loopInfo =
-      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
-  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
-
   // TODO: Handle doacross loops when the ordered clause has a parameter.
   bool isOrdered = wsloopOp.getOrdered().has_value();
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
+  bool loopNeedsBarrier = !wsloopOp.getNowait();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
+
+  if (failed(handleError(regionBlock, opInst)))
+    return failure();
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
+          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -2047,12 +1999,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
-  // Continue building IR after the loop. Note that the LoopInfo returned by
-  // `collapseLoops` points inside the outermost loop and is intended for
-  // potential further loop transformations. Use the insertion point stored
-  // before collapsing loops instead.
-  builder.restoreIP(afterIP);
-
   // Process the reductions if required.
   if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
                                         allocaIP, reductionDecls,
@@ -2261,8 +2207,20 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto simdOp = cast<omp::SimdOp>(opInst);
-  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
+
+  // TODO: Replace this with proper composite translation support.
+  // Currently, simd information on composite constructs is ignored, so e.g.
+  // 'do/for simd' will be treated the same as a standalone 'do/for'. This is
+  // allowed by the spec, since it's equivalent to using a SIMD length of 1.
+  if (simdOp.isComposite()) {
+    if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
+      return failure();
+
+    return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
+                                   builder, moduleTranslation);
+  }
 
   if (failed(checkImplementationStatus(opInst)))
     return failure();
@@ -2295,6 +2253,61 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
           .failed())
     return failure();
 
+  llvm::ConstantInt *simdlen = nullptr;
+  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
+    simdlen = builder.getInt64(simdlenVar.value());
+
+  llvm::ConstantInt *safelen = nullptr;
+  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
+    safelen = builder.getInt64(safelenVar.value());
+
+  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
+  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+  mlir::OperandRange operands = simdOp.getAlignedVars();
+  for (size_t i = 0; i < operands.size(); ++i) {
+    llvm::Value *alignment = nullptr;
+    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+    llvm::Type *ty = llvmVal->getType();
+    if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
+      alignment = builder.getInt64(intAttr.getInt());
+      assert(ty->isPointerTy() && "Invalid type for aligned variable");
+      assert(alignment && "Invalid alignment value");
+      auto curInsert = builder.saveIP();
+      builder.SetInsertPoint(sourceBlock);
+      llvmVal = builder.CreateLoad(ty, llvmVal);
+      builder.restoreIP(curInsert);
+      alignedVars[llvmVal] = alignment;
+    }
+  }
+
+  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
+
+  if (failed(handleError(regionBlock, opInst)))
+    return failure();
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
+  ompBuilder->applySimd(loopInfo, alignedVars,
+                        simdOp.getIfExpr()
+                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                            : nullptr,
+                        order, simdlen, safelen);
+
+  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
+                            llvmPrivateVars, privateDecls);
+}
+
+/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  auto loopOp = cast<omp::LoopNestOp>(opInst);
+
+  // Set up the source location value for OpenMP runtime.
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
   // Generator of the canonical loop body.
@@ -2316,9 +2329,13 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
     // Convert the body of the loop.
     builder.restoreIP(ip);
-    return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
-                               moduleTranslation)
-        .takeError();
+    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
+    if (!regionBlock)
+      return regionBlock.takeError();
+
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    return llvm::Error::success();
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
@@ -2326,7 +2343,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   // loop, i.e. it has a positive step, uses signed integer semantics.
   // Reconsider this code when the nested loop operation clearly supports more
   // cases.
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
         moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
@@ -2348,7 +2364,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
         ompBuilder->createCanonicalLoop(
             loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
+            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
 
     if (failed(handleError(loopResult, *loopOp)))
       return failure();
@@ -2356,49 +2372,23 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     loopInfos.push_back(*loopResult);
   }
 
-  // Collapse loops.
-  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
-  llvm::CanonicalLoopInfo *loopInfo =
-      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
-  llvm::ConstantInt *simdlen = nullptr;
-  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
-    simdlen = builder.getInt64(simdlenVar.value());
+  // Collapse loops. Store the insertion point because LoopInfos may get
+  // invalidated.
+  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
+      loopInfos.front()->getAfterIP();
 
-  llvm::ConstantInt *safelen = nullptr;
-  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
-    safelen = builder.getInt64(safelenVar.value());
-
-  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
-  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
-  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
-  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
-  mlir::OperandRange...
[truncated]

@llvmbot (Member) commented Feb 14, 2025

@llvm/pr-subscribers-mlir-llvm

@skatrak (Member, Author) commented Feb 14, 2025

The first commit in this PR is #115475, since this change would otherwise trigger a bug in the translation of loop wrappers and a compiler crash. If this patch is accepted, that one will be merged first.

@Meinersbur (Member) left a comment

One consequence of this was fixing an issue lowering non-inclusive omp.simd loops.

What is the issue? Off-by-one error/forgotten inclusive?

This patch centralizes that handling of the loop.

old: convertOmpWsloop,

new: convertOmpLoopNest

I still see a lot of code moved into convertOmpSimd? Why is that?

// Add a stack frame holding information about the resulting loop after
// applying transformations, to be further transformed by parent loop
// wrappers.
moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(
Member

I see stackPush only being used with SaveStack. Is it possible to hoist this up the call graph so that both stackPush/stackPop are in the same function?

The pop is conditional in convertHostOrTargetOperation; the pairing is not obvious to me.

Member Author

This is a tricky issue. The problem with SaveStack is that it pushes and pops from within the same function. In this case, we could move those two calls to convertHostOrTargetOperation, but they would have to be only done conditionally for the outermost loop wrapper. That means we couldn't use the SaveStack class either way, but at least the push and pop calls would be located in the same function (which is certainly easier to follow).

However, if we do that, another issue that we encounter is that the current design of StackFrame handling expects them to be created once and never modified. Here, we want to set the loopInfo pointer based on the value produced while translating the inner omp.loop_nest and then use it during translation of the associated parent loop wrappers. So, we would have to create a stack frame with a null value initially and then update it later, which means introducing some way of accessing it in a non-const way (removing const from ModuleTranslation::stackWalk, for example).

This all boils down to the core issue of trying to pass information from inside (the nested loop) to outside (its wrappers), and the stack frame system is designed to do the opposite. It doesn't quite fit either way we try to do it, but so far it's the best I've been able to come up with (the alternative being a global variable). The implementation I proposed here is the one that doesn't introduce changes to ModuleTranslation for something that's a very unusual use of the stack frame feature.

The pop is conditional in convertHostOrTargetOperation; the pairing is not obvious to me.

For each omp.loop_nest there will be 1 or more loop wrappers. If we push a stack frame when we create the omp.loop_nest to be used up the chain, the place where we can finally pop that stack frame is when we finish processing the outermost loop wrapper. That's how the pairing is done.
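
For reference, a fragment-level sketch of the contrast being described (simplified, not code from the patch; the SaveStack usage mirrors how the alloca stack frame is handled elsewhere in this file):

// RAII usage (the usual pattern): push and pop are tied to one scope, so
// information only flows *down* into nested translations.
{
  LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
      moduleTranslation, allocaIP);
  // ... translate nested operations; the frame is popped automatically when
  // `frame` goes out of scope.
}
// The loop-info frame has to flow *up* instead (filled in by
// convertOmpLoopNest, consumed by the wrappers above it), so it needs an
// explicit stackPush/stackPop pair guarded by the same condition, as in the
// sketch after the PR description above.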

Member Author

Having said that, I noticed that the current approach can break when other stack frames are inserted by loop wrappers, because it messes up their order. So they can end up popping the stack frame inserted by the omp.loop_nest rather than the one they pushed prior to translating the loop.

I'm going to work on the alternative implementation where push and pop are done conditionally in convertHostOrTargetOperation, and make the necessary changes to ModuleTranslation so that there's a non-const way of accessing the stack frames.

Member Author

I just pushed that change, hopefully this approach makes more sense to you.

Comment on lines +4688 to +4818
.Case([&](omp::LoopNestOp) {
return convertOmpLoopNest(*op, builder, moduleTranslation);
})
Member

This is what centralizes it (Ops calling convertHostOrTargetOperation to do the handling of LoopNestOp)?

Member Author

Exactly. Now wrappers expect the loop nest to be translated when they call convertOmpOpRegions, and that is triggered by this path.

@@ -2348,57 +2364,31 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
ompBuilder->createCanonicalLoop(
loc, bodyGen, lowerBound, upperBound, step,
/*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
Member

The simd issue?

Member Author

Yes, we weren't looking at the inclusive attribute for omp.simd.
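
For context, the difference the InclusiveStop flag makes, illustrated with plain C++ loops (illustration only, assuming the usual createCanonicalLoop semantics; not code from the patch):

// Loops conceptually built by createCanonicalLoop for bounds (lb, ub, step):
void inclusiveStop(int lb, int ub, int step) {
  for (int iv = lb; iv <= ub; iv += step) { /* body */ } // InclusiveStop=true
}
void exclusiveStop(int lb, int ub, int step) {
  for (int iv = lb; iv < ub; iv += step) { /* body */ }  // InclusiveStop=false
}
// Hard-coding /*InclusiveStop=*/true therefore ran one extra iteration for
// omp.simd loop nests whose bounds were not marked `inclusive`.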

@tblah (Contributor) left a comment

Thanks for the clean up. Just some nits from me to accompany Michael's questions.

@skatrak (Member, Author) left a comment

I still see a lot of code moved into convertOmpSimd? Why is that?

I think the diff is messing up the view and making it more confusing. What was basically done was moving the omp.loop_nest lowering, which was duplicated in convertOmpWsloop and convertOmpSimd, into a new convertOmpLoopNest function. Everything that was wrapper-specific remained in the respective wrapper translation function, and the only thing that got moved from convertOmpWsloop into convertOmpSimd was the composite check, which results in SIMD information being ignored (same behavior as before, but checked where I think it makes more sense).
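
Schematically, the before/after call structure (paraphrased from the diff, not literal code):

// Before: each wrapper lowered the loop nest itself.
//   convertOmpWsloop: createCanonicalLoop per loop -> collapseLoops
//                     -> applyWorkshareLoop
//   convertOmpSimd:   createCanonicalLoop per loop -> collapseLoops
//                     -> applySimd
//
// After: wrappers translate their region like any other operation; the nested
// omp.loop_nest is dispatched to the new convertOmpLoopNest, and the wrapper
// picks up the resulting CanonicalLoopInfo afterwards.
//   convertOmpWsloop: convertOmpOpRegions -> convertOmpLoopNest
//                     (createCanonicalLoop per loop -> collapseLoops,
//                      recorded in the stack frame)
//                     -> findCurrentLoopInfo -> applyWorkshareLoop
//   convertOmpSimd:   same, then applySimd; for composite 'do/for simd' the
//                     simd wrapper is ignored, as before.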

@Meinersbur (Member) left a comment

Having the stack push/pop with the same condition (isOutermostLoopWrapper) makes the balancing obvious, thanks.

LGTM

@jsjodin (Contributor) left a comment

LGTM

This patch refactors the translation of `omp.loop_nest` operations into LLVM IR
so that it is handled similarly to other operations. Before this change, the
responsibility of translating the loop nest fell into each loop wrapper,
causing code duplication. This patch centralizes that handling of the loop.
One consequence of this was fixing an issue lowering non-inclusive `omp.simd`
loops.

As a result, it is now expected that the handling of composite constructs is
performed collaboratively among translating functions for each operation
involved. At the moment, only `do/for simd` is supported by ignoring SIMD
information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the
`llvm::CanonicalLoopInfo` loop information structure in order to apply
transformations to it. This is now created in the nested call to
`convertOmpLoopNest`, so it needs to be passed up to all associated loop
wrapper translation functions. This is done via the creation of an
`OpenMPLoopInfoStackFrame` within `convertHostOrTargetOperation`, associated to
the outermost loop wrapper. This structure is updated by `convertOmpLoopNest`,
making the result available to all loop wrappers after their body has been
translated.
@skatrak force-pushed the users/skatrak/target-spmd-01-loop-refactor branch from 0ef4436 to 343f573 on February 24, 2025 12:13
@skatrak merged commit c83bdc7 into main on Feb 24, 2025
11 checks passed
@skatrak deleted the users/skatrak/target-spmd-01-loop-refactor branch on February 24, 2025 13:38