[MLIR][OpenMP] Normalize lowering of omp.loop_nest #127217
Conversation
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch refactors the translation of `omp.loop_nest` operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell into each loop wrapper, causing code duplication. This patch centralizes that handling of the loop. One consequence of this was fixing an issue lowering non-inclusive `omp.simd` loops.

As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only `do/for simd` is supported by ignoring SIMD information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the `llvm::CanonicalLoopInfo` loop information structure in order to apply transformations to it. This is now created in the nested call to `convertOmpLoopNest` and passed up to the associated loop wrapper translation functions via an `OpenMPLoopInfoStackFrame`.

Patch is 36.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127217.diff

3 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 51a3cbdbb5e7f..a5ff3eff6439f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -75,6 +75,19 @@ class OpenMPAllocaStackFrame
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
+/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
+/// collapsed canonical loop information corresponding to an \c omp.loop_nest
+/// operation.
+class OpenMPLoopInfoStackFrame
+ : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
+
+ explicit OpenMPLoopInfoStackFrame(llvm::CanonicalLoopInfo *loopInfo)
+ : loopInfo(loopInfo) {}
+ llvm::CanonicalLoopInfo *loopInfo;
+};
+
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
@@ -372,6 +385,20 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
&funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
+/// Find the loop information structure for the loop nest being translated. It
+/// will not return a value unless called from the translation function for
+/// a loop wrapper operation after successfully translating its body.
+static std::optional<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+ std::optional<llvm::CanonicalLoopInfo *> loopInfo;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](const OpenMPLoopInfoStackFrame &frame) {
+ loopInfo = frame.loopInfo;
+ return WalkResult::interrupt();
+ });
+ return loopInfo;
+}
+
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with an successor-less OpenMP terminator
@@ -381,6 +408,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+ bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
llvm::BasicBlock *continuationBlock =
splitBB(builder, true, "omp.region.cont");
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -397,30 +426,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
// Terminators (namely YieldOp) may be forwarding values to the region that
// need to be available in the continuation block. Collect the types of these
- // operands in preparation of creating PHI nodes.
+ // operands in preparation of creating PHI nodes. This is skipped for loop
+ // wrapper operations, for which we know in advance they have no terminators.
SmallVector<llvm::Type *> continuationBlockPHITypes;
- bool operandsProcessed = false;
unsigned numYields = 0;
- for (Block &bb : region.getBlocks()) {
- if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
- if (!operandsProcessed) {
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- continuationBlockPHITypes.push_back(
- moduleTranslation.convertType(yield->getOperand(i).getType()));
- }
- operandsProcessed = true;
- } else {
- assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
- "mismatching number of values yielded from the region");
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- llvm::Type *operandType =
- moduleTranslation.convertType(yield->getOperand(i).getType());
- (void)operandType;
- assert(continuationBlockPHITypes[i] == operandType &&
- "values of mismatching types yielded from the region");
+
+ if (!isLoopWrapper) {
+ bool operandsProcessed = false;
+ for (Block &bb : region.getBlocks()) {
+ if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+ if (!operandsProcessed) {
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ continuationBlockPHITypes.push_back(
+ moduleTranslation.convertType(yield->getOperand(i).getType()));
+ }
+ operandsProcessed = true;
+ } else {
+ assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+ "mismatching number of values yielded from the region");
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ llvm::Type *operandType =
+ moduleTranslation.convertType(yield->getOperand(i).getType());
+ (void)operandType;
+ assert(continuationBlockPHITypes[i] == operandType &&
+ "values of mismatching types yielded from the region");
+ }
}
+ numYields++;
}
- numYields++;
}
}
@@ -458,6 +491,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
return llvm::make_error<PreviouslyReportedError>();
+ // Create a direct branch here for loop wrappers to prevent their lack of a
+ // terminator from causing a crash below.
+ if (isLoopWrapper) {
+ builder.CreateBr(continuationBlock);
+ continue;
+ }
+
// Special handling for `omp.yield` and `omp.terminator` (we may have more
// than one): they return the control to the parent OpenMP dialect operation
// so replace them with the branch to the continuation block. We handle this
@@ -509,7 +549,7 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
/// This must be called after block arguments of parent wrappers have already
/// been mapped to LLVM IR values.
static LogicalResult
-convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
+convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
LLVM::ModuleTranslation &moduleTranslation) {
// Map block arguments directly to the LLVM value associated to the
// corresponding operand. This is semantically equivalent to this wrapper not
@@ -531,34 +571,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
return success();
})
.Default([&](Operation *op) {
- return op->emitError() << "cannot ignore nested wrapper";
+ return op->emitError() << "cannot ignore wrapper";
});
}
-/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
-/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
-/// entry block arguments defined by these operations to outside values.
-///
-/// It must be called after block arguments of \c parentOp have already been
-/// mapped themselves.
-static LogicalResult
-convertIgnoredWrappers(omp::LoopNestOp loopOp,
- omp::LoopWrapperInterface parentOp,
- LLVM::ModuleTranslation &moduleTranslation) {
- SmallVector<omp::LoopWrapperInterface> wrappers;
- loopOp.gatherWrappers(wrappers);
-
- // Process wrappers nested inside of `parentOp` from outermost to innermost.
- for (auto it =
- std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
- it != wrappers.rend(); ++it) {
- if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
- return failure();
- }
-
- return success();
-}
-
/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1876,6 +1892,7 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto wsloopOp = cast<omp::WsloopOp>(opInst);
if (failed(checkImplementationStatus(opInst)))
return failure();
@@ -1956,90 +1973,25 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
reductionVariableMap, isByRef, deferredStores)))
return failure();
- // TODO: Replace this with proper composite translation support.
- // Currently, all nested wrappers are ignored, so 'do/for simd' will be
- // treated the same as a standalone 'do/for'. This is allowed by the spec,
- // since it's equivalent to always using a SIMD length of 1.
- if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
- return failure();
-
- // Set up the source location value for OpenMP runtime.
- llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-
- // Generator of the canonical loop body.
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
- SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
- auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
- llvm::Value *iv) -> llvm::Error {
- // Make sure further conversions know about the induction variable.
- moduleTranslation.mapValue(
- loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
-
- // Capture the body insertion point for use in nested loops. BodyIP of the
- // CanonicalLoopInfo always points to the beginning of the entry block of
- // the body.
- bodyInsertPoints.push_back(ip);
-
- if (loopInfos.size() != loopOp.getNumLoops() - 1)
- return llvm::Error::success();
-
- // Convert the body of the loop.
- builder.restoreIP(ip);
- return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
- moduleTranslation)
- .takeError();
- };
-
- // Delegate actual loop construction to the OpenMP IRBuilder.
- // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
- // loop, i.e. it has a positive step, uses signed integer semantics.
- // Reconsider this code when the nested loop operation clearly supports more
- // cases.
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
- llvm::Value *lowerBound =
- moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
- llvm::Value *upperBound =
- moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
- llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
-
- // Make sure loop trip count are emitted in the preheader of the outermost
- // loop at the latest so that they are all available for the new collapsed
- // loop will be created below.
- llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
- llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
- if (i != 0) {
- loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
- computeIP = loopInfos.front()->getPreheaderIP();
- }
-
- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
- ompBuilder->createCanonicalLoop(
- loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
- if (failed(handleError(loopResult, *loopOp)))
- return failure();
-
- loopInfos.push_back(*loopResult);
- }
-
- // Collapse loops. Store the insertion point because LoopInfos may get
- // invalidated.
- llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
- llvm::CanonicalLoopInfo *loopInfo =
- ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
- allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
-
// TODO: Handle doacross loops when the ordered clause has a parameter.
bool isOrdered = wsloopOp.getOrdered().has_value();
std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
bool isSimd = wsloopOp.getScheduleSimd();
+ bool loopNeedsBarrier = !wsloopOp.getNowait();
+
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
+
+ if (failed(handleError(regionBlock, opInst)))
+ return failure();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
convertToScheduleKind(schedule), chunk, isSimd,
scheduleMod == omp::ScheduleModifier::monotonic,
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -2047,12 +1999,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
- // Continue building IR after the loop. Note that the LoopInfo returned by
- // `collapseLoops` points inside the outermost loop and is intended for
- // potential further loop transformations. Use the insertion point stored
- // before collapsing loops instead.
- builder.restoreIP(afterIP);
-
// Process the reductions if required.
if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
allocaIP, reductionDecls,
@@ -2261,8 +2207,20 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto simdOp = cast<omp::SimdOp>(opInst);
- auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
+
+ // TODO: Replace this with proper composite translation support.
+ // Currently, simd information on composite constructs is ignored, so e.g.
+ // 'do/for simd' will be treated the same as a standalone 'do/for'. This is
+ // allowed by the spec, since it's equivalent to using a SIMD length of 1.
+ if (simdOp.isComposite()) {
+ if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
+ return failure();
+
+ return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
+ builder, moduleTranslation);
+ }
if (failed(checkImplementationStatus(opInst)))
return failure();
@@ -2295,6 +2253,61 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
.failed())
return failure();
+ llvm::ConstantInt *simdlen = nullptr;
+ if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
+ simdlen = builder.getInt64(simdlenVar.value());
+
+ llvm::ConstantInt *safelen = nullptr;
+ if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
+ safelen = builder.getInt64(safelenVar.value());
+
+ llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
+ llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+ llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+ std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+ mlir::OperandRange operands = simdOp.getAlignedVars();
+ for (size_t i = 0; i < operands.size(); ++i) {
+ llvm::Value *alignment = nullptr;
+ llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+ llvm::Type *ty = llvmVal->getType();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
+ alignment = builder.getInt64(intAttr.getInt());
+ assert(ty->isPointerTy() && "Invalid type for aligned variable");
+ assert(alignment && "Invalid alignment value");
+ auto curInsert = builder.saveIP();
+ builder.SetInsertPoint(sourceBlock);
+ llvmVal = builder.CreateLoad(ty, llvmVal);
+ builder.restoreIP(curInsert);
+ alignedVars[llvmVal] = alignment;
+ }
+ }
+
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
+
+ if (failed(handleError(regionBlock, opInst)))
+ return failure();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
+ ompBuilder->applySimd(loopInfo, alignedVars,
+ simdOp.getIfExpr()
+ ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
+
+ return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
+ llvmPrivateVars, privateDecls);
+}
+
+/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ auto loopOp = cast<omp::LoopNestOp>(opInst);
+
+ // Set up the source location value for OpenMP runtime.
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Generator of the canonical loop body.
@@ -2316,9 +2329,13 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
// Convert the body of the loop.
builder.restoreIP(ip);
- return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
- moduleTranslation)
- .takeError();
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
+ if (!regionBlock)
+ return regionBlock.takeError();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ return llvm::Error::success();
};
// Delegate actual loop construction to the OpenMP IRBuilder.
@@ -2326,7 +2343,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
// loop, i.e. it has a positive step, uses signed integer semantics.
// Reconsider this code when the nested loop operation clearly supports more
// cases.
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
llvm::Value *lowerBound =
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
@@ -2348,7 +2364,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
ompBuilder->createCanonicalLoop(
loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
if (failed(handleError(loopResult, *loopOp)))
return failure();
@@ -2356,49 +2372,23 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
loopInfos.push_back(*loopResult);
}
- // Collapse loops.
- llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
- llvm::CanonicalLoopInfo *loopInfo =
- ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
- llvm::ConstantInt *simdlen = nullptr;
- if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
- simdlen = builder.getInt64(simdlenVar.value());
+ // Collapse loops. Store the insertion point because LoopInfos may get
+ // invalidated.
+ llvm::OpenMPIRBuilder::InsertPointTy afterIP =
+ loopInfos.front()->getAfterIP();
- llvm::ConstantInt *safelen = nullptr;
- if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
- safelen = builder.getInt64(safelenVar.value());
-
- llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
- llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
- llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
- std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
- mlir::OperandRange...
[truncated]
The first commit in this PR is #115475, since this change would otherwise trigger a bug in the translation of loop wrappers and a compiler crash. If this patch is accepted, that one will be merged first. |
One consequence of this was fixing an issue lowering non-inclusive omp.simd loops.

What is the issue? An off-by-one error / a forgotten `inclusive`?

This patch centralizes that handling of the loop.

old: convertOmpWsloop, new: convertOmpLoopNest

I still see a lot of code moved into `convertOmpSimd`? Why is that?
// Add a stack frame holding information about the resulting loop after
// applying transformations, to be further transformed by parent loop
// wrappers.
moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(
I see `stackPush` only being used with `SaveStack`. Is it possible to hoist this up the call graph so that both `stackPush`/`stackPop` are in the same function?

The pop is conditional in `convertHostOrTargetOperation`; the pairing is not obvious to me.
This is a tricky issue. The problem with `SaveStack` is that it pushes and pops from within the same function. In this case, we could move those two calls to `convertHostOrTargetOperation`, but they would have to be done conditionally, only for the outermost loop wrapper. That means we couldn't use the `SaveStack` class either way, but at least the push and pop calls would be located in the same function (which is certainly easier to follow).

However, if we do that, another issue we encounter is that the current design of `StackFrame` handling expects frames to be created once and never modified. Here, we want to set the `loopInfo` pointer based on the value produced while translating the inner `omp.loop_nest` and then use it during translation of the associated parent loop wrappers. So, we would have to create a stack frame with a null value initially and then update it, which means creating some way of accessing it in a non-const way (removing `const` from `ModuleTranslation::stackWalk`, for example).

This all boils down to the core issue of trying to pass information from inside (the nested loop) to outside (its wrappers), while the stack frame system is designed to do the opposite. It doesn't quite fit either way we try it, but so far it's the best I've been able to come up with (the alternative being a global variable). The implementation I proposed here is the one that doesn't introduce changes to `ModuleTranslation` for something that's a very unusual use of the stack frame feature.

> The pop is conditional in `convertHostOrTargetOperation`; the pairing is not obvious to me.

For each `omp.loop_nest` there will be one or more loop wrappers. If we push a stack frame when we create the `omp.loop_nest`, to be used up the chain, the place where we can finally pop that stack frame is when we finish processing the outermost loop wrapper. That's how the pairing is done.
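To make the contrast concrete, the fragment below is a simplified sketch of the RAII pattern that `SaveStack` supports today, assuming the surrounding context of OpenMPToLLVMIRTranslation.cpp rather than reproducing an exact call site: push on construction, pop at scope exit, both inside one function. That shape fits passing context down into nested translation, but not returning a `CanonicalLoopInfo` back up to the wrappers.

```cpp
// Simplified sketch of the existing SaveStack usage pattern: push a frame for
// the duration of a scope so that nested translation can observe it via
// stackWalk, then pop automatically when the scope ends.
{
  LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
      moduleTranslation, allocaIP);
  // ... translate the nested region here; callees walk the stack to find the
  // alloca insertion point ...
} // frame is popped when it goes out of scope
```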
Having said that, I noticed that the current approach can break when other stack frames are inserted by loop wrappers, because it messes up their order: they can end up popping the stack frame inserted by the `omp.loop_nest` rather than the one they pushed prior to translating the loop.

I'm going to work on the alternative implementation where `push` and `pop` are done conditionally in `convertHostOrTargetOperation`, and make the necessary changes to `ModuleTranslation` so that there's a non-const way of accessing the stack frames.
I just pushed that change, hopefully this approach makes more sense to you.
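For readers following the thread, the shape of that change looks roughly like the sketch below. This is a simplified, partly hypothetical reconstruction based on the discussion, not the exact upstream code: the condition used for `isOutermostLoopWrapper`, the frame constructor arguments, and the dispatch cases are assumptions, and the necessary includes are those already present in OpenMPToLLVMIRTranslation.cpp. The frame is pushed before translating the outermost loop wrapper, filled in by `convertOmpLoopNest`, and popped under the same condition once that wrapper has been translated.

```cpp
// Simplified sketch (not the exact upstream code) of the conditional
// push/pop pairing in convertHostOrTargetOperation.
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  // Assumption: a loop wrapper is "outermost" if its parent is not itself a
  // loop wrapper.
  Operation *parent = op->getParentOp();
  bool isOutermostLoopWrapper =
      isa<omp::LoopWrapperInterface>(op) &&
      (!parent || !isa<omp::LoopWrapperInterface>(parent));

  // The frame starts out empty; convertOmpLoopNest sets its loopInfo member.
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(
        /*loopInfo=*/nullptr);

  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          // ... remaining operation cases elided ...
          .Default([&](Operation *) {
            // Error reporting elided in this sketch.
            return failure();
          });

  // Pop under the same condition as the push, so the pairing stays balanced.
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();
  return result;
}
```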
.Case([&](omp::LoopNestOp) {
  return convertOmpLoopNest(*op, builder, moduleTranslation);
})
This is what centralizes it (Ops calling `convertHostOrTargetOperation` to do the handling of `LoopNestOp`)?
Exactly. Now wrappers expect the loop nest to be translated when they call `convertOmpOpRegions`, and that is triggered by this path.
@@ -2348,57 +2364,31 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
       ompBuilder->createCanonicalLoop(
           loc, bodyGen, lowerBound, upperBound, step,
-          /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
+          /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
The simd issue?
Yes, we weren't looking at the `inclusive` attribute for `omp.simd`.
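To make the off-by-one concrete: the snippet below is a standalone illustration with made-up bounds (not taken from the patch or its tests). Always forcing `InclusiveStop=true`, as the pre-patch simd path did, executes one extra iteration for a loop nest whose upper bound is exclusive.

```cpp
// Standalone illustration of the inclusive/exclusive off-by-one.
// Assumes a positive step and lb <= ub, matching what the lowering assumes.
#include <cstdio>

static int tripCount(int lb, int ub, int step, bool inclusiveStop) {
  if (inclusiveStop)
    return (ub - lb) / step + 1;      // the iteration variable may reach ub
  return (ub - lb + step - 1) / step; // ub itself is never reached
}

int main() {
  // A loop nest from 0 to 10 with step 1 and no `inclusive` flag:
  std::printf("exclusive stop: %d iterations\n", tripCount(0, 10, 1, false)); // 10
  // The same bounds lowered with InclusiveStop=true, as before this patch:
  std::printf("inclusive stop: %d iterations\n", tripCount(0, 10, 1, true));  // 11
  return 0;
}
```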
Thanks for the clean up. Just some nits from me to accompany Michael's questions.
> I still see a lot of code moved into `convertOmpSimd`? Why is that?

I think the diff is messing up the view and making it more confusing. What was basically done was moving the `omp.loop_nest` lowering, which was duplicated in `convertOmpWsloop` and `convertOmpSimd`, into a new `convertOmpLoopNest` function. Everything that was wrapper-specific remained in the respective wrapper translation function, and the only thing that got moved from `convertOmpWsloop` into `convertOmpSimd` was the composite check, which results in SIMD information being ignored (same behavior as before, but checked where I think it makes more sense).
Having the stack push/pop with the same condition (`isOutermostLoopWrapper`) makes the balancing obvious, thanks.

LGTM
LGTM
This patch refactors the translation of `omp.loop_nest` operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell into each loop wrapper, causing code duplication. This patch centralizes that handling of the loop. One consequence of this was fixing an issue lowering non-inclusive `omp.simd` loops.

As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only `do/for simd` is supported by ignoring SIMD information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the `llvm::CanonicalLoopInfo` loop information structure in order to apply transformations to it. This is now created in the nested call to `convertOmpLoopNest`, so it needs to be passed up to all associated loop wrapper translation functions. This is done via the creation of an `OpenMPLoopInfoStackFrame` within `convertHostOrTargetOperation`, associated to the outermost loop wrapper. This structure is updated by `convertOmpLoopNest`, making the result available to all loop wrappers after their body has been translated.
Force-pushed from 0ef4436 to 343f573.