-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NFC][OpenMP][MLIR] Refactor code related to collecting privatizer info into a shared util #131582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NFC][OpenMP][MLIR] Refactor code related to collecting privatizer info into a shared util #131582
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp Author: Kareem Ergawy (ergawy) ChangesMoves code needed to collect info about delayed privatizers into a shared util instread of repeating the same patter across all relevant constructs. Patch is 22.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131582.diff 1 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 17d0a7007729f..315c6b8ccc553 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -696,20 +696,42 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
-/// Populates `privatizations` with privatization declarations used for the
-/// given op.
-template <class OP>
-static void collectPrivatizationDecls(
- OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
- std::optional<ArrayAttr> attr = op.getPrivateSyms();
- if (!attr)
- return;
+/// A util to collect info needed to convert delayed privatizers from MLIR to
+/// LLVM.
+struct PrivateVarsInfo {
+ template <typename OP>
+ PrivateVarsInfo(OP op)
+ : privateBlockArgs(
+ cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
+ mlirPrivateVars.reserve(privateBlockArgs.size());
+ llvmPrivateVars.reserve(privateBlockArgs.size());
+ collectPrivatizationDecls<OP>(op, privateDecls);
- privatizations.reserve(privatizations.size() + attr->size());
- for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
- privatizations.push_back(findPrivatizer(op, symbolRef));
+ for (mlir::Value privateVar : op.getPrivateVars())
+ mlirPrivateVars.push_back(privateVar);
}
-}
+
+ MutableArrayRef<BlockArgument> privateBlockArgs;
+ SmallVector<mlir::Value> mlirPrivateVars;
+ SmallVector<llvm::Value *> llvmPrivateVars;
+ SmallVector<omp::PrivateClauseOp> privateDecls;
+
+private:
+ /// Populates `privatizations` with privatization declarations used for the
+ /// given op.
+ template <class OP>
+ static void collectPrivatizationDecls(
+ OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
+ std::optional<ArrayAttr> attr = op.getPrivateSyms();
+ if (!attr)
+ return;
+
+ privatizations.reserve(privatizations.size() + attr->size());
+ for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
+ privatizations.push_back(findPrivatizer(op, symbolRef));
+ }
+ }
+};
/// Populates `reductions` with reduction declarations used in the given op.
template <typename T>
@@ -1384,19 +1406,18 @@ static llvm::Expected<llvm::Value *> initPrivateVar(
static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- MutableArrayRef<BlockArgument> privateBlockArgs,
- MutableArrayRef<omp::PrivateClauseOp> privateDecls,
- MutableArrayRef<mlir::Value> mlirPrivateVars,
- llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+ PrivateVarsInfo &privateVarsInfo,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
- if (privateBlockArgs.empty())
+ if (privateVarsInfo.privateBlockArgs.empty())
return llvm::Error::success();
llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
- privateDecls, mlirPrivateVars, privateBlockArgs, llvmPrivateVars))) {
+ privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs,
+ privateVarsInfo.llvmPrivateVars))) {
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
@@ -1420,10 +1441,7 @@ initPrivateVars(llvm::IRBuilderBase &builder,
static llvm::Expected<llvm::BasicBlock *>
allocatePrivateVars(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- MutableArrayRef<BlockArgument> privateBlockArgs,
- MutableArrayRef<omp::PrivateClauseOp> privateDecls,
- MutableArrayRef<mlir::Value> mlirPrivateVars,
- llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+ PrivateVarsInfo &privateVarsInfo,
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// Allocate private vars
@@ -1449,8 +1467,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
->getDataLayout()
.getProgramAddressSpace();
- for (auto [privDecl, mlirPrivVar, blockArg] :
- llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
+ for (auto [privDecl, mlirPrivVar, blockArg] : llvm::zip_equal(
+ privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs)) {
llvm::Type *llvmAllocType =
moduleTranslation.convertType(privDecl.getType());
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
@@ -1460,7 +1479,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
builder.getPtrTy(defaultAS));
- llvmPrivateVars.push_back(llvmPrivateVar);
+ privateVarsInfo.llvmPrivateVars.push_back(llvmPrivateVar);
}
return afterAllocas;
@@ -1888,19 +1907,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*taskOp)))
return failure();
- // Collect delayed privatisation declarations
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(taskOp, privateDecls);
+ PrivateVarsInfo privateVarsInfo(taskOp);
TaskContextStructManager taskStructMgr{builder, moduleTranslation,
- privateDecls};
- for (mlir::Value privateVar : taskOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ privateVarsInfo.privateDecls};
// Allocate and copy private variables before creating the task. This avoids
// accessing invalid memory if (after this scope ends) the private variables
@@ -1959,7 +1968,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
taskStructMgr.createGEPsToPrivateVars();
for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
- llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs,
+ llvm::zip_equal(privateVarsInfo.privateDecls,
+ privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs,
taskStructMgr.getLLVMPrivateVarGEPs())) {
// To be handled inside the task.
if (!privDecl.readsFromMold())
@@ -1998,9 +2009,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
// firstprivate copy region
setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- taskStructMgr.getLLVMPrivateVarGEPs(),
- privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privateDecls)))
return llvm::failure();
// Set up for call to createTask()
@@ -2017,9 +2028,11 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
builder.restoreIP(codegenIP);
llvm::BasicBlock *privInitBlock = nullptr;
- llvmPrivateVars.resize(privateBlockArgs.size());
+ privateVarsInfo.llvmPrivateVars.resize(
+ privateVarsInfo.privateBlockArgs.size());
for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
- privateBlockArgs, privateDecls, mlirPrivateVars))) {
+ privateVarsInfo.privateBlockArgs, privateVarsInfo.privateDecls,
+ privateVarsInfo.mlirPrivateVars))) {
auto [blockArg, privDecl, mlirPrivVar] = zip;
// This is handled before the task executes
if (privDecl.readsFromMold())
@@ -2038,23 +2051,25 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (!privateVarOrError)
return privateVarOrError.takeError();
moduleTranslation.mapValue(blockArg, privateVarOrError.get());
- llvmPrivateVars[i] = privateVarOrError.get();
+ privateVarsInfo.llvmPrivateVars[i] = privateVarOrError.get();
}
taskStructMgr.createGEPsToPrivateVars();
for (auto [i, llvmPrivVar] :
llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
if (!llvmPrivVar) {
- assert(llvmPrivateVars[i] && "This is added in the loop above");
+ assert(privateVarsInfo.llvmPrivateVars[i] &&
+ "This is added in the loop above");
continue;
}
- llvmPrivateVars[i] = llvmPrivVar;
+ privateVarsInfo.llvmPrivateVars[i] = llvmPrivVar;
}
// Find and map the addresses of each variable within the task context
// structure
- for (auto [blockArg, llvmPrivateVar, privateDecl] :
- llvm::zip_equal(privateBlockArgs, llvmPrivateVars, privateDecls)) {
+ for (auto [blockArg, llvmPrivateVar, privateDecl] : llvm::zip_equal(
+ privateVarsInfo.privateBlockArgs, privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)) {
// This was handled above.
if (!privateDecl.readsFromMold())
continue;
@@ -2076,7 +2091,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
- llvmPrivateVars, privateDecls)))
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
// Free heap allocated task context structure at the end of the task.
@@ -2171,17 +2187,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
}
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(wsloopOp, privateDecls);
-
- for (mlir::Value privateVar : wsloopOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(wsloopOp);
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(wsloopOp, reductionDecls);
@@ -2192,8 +2198,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
wsloopOp.getNumReductionVars());
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return failure();
@@ -2210,15 +2215,14 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
deferredStores, isByRef)))
return failure();
- if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
- privateDecls, mlirPrivateVars,
- llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
opInst)
.failed())
return failure();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
return failure();
assert(afterAllocas.get()->getSinglePredecessor());
@@ -2271,7 +2275,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
- llvmPrivateVars, privateDecls);
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls);
}
/// Converts the OpenMP parallel operation to LLVM IR.
@@ -2286,17 +2291,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*opInst)))
return failure();
- // Collect delayed privatization declarations
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(opInst, privateDecls);
- for (mlir::Value privateVar : opInst.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(opInst);
// Collect reduction declarations
SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2308,8 +2303,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
auto bodyGenCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> llvm::Error {
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, *opInst).failed())
return llvm::make_error<PreviouslyReportedError>();
@@ -2332,15 +2326,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
assert(afterAllocas.get()->getSinglePredecessor());
builder.restoreIP(codeGenIP);
- if (handleError(initPrivateVars(builder, moduleTranslation,
- privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars),
- *opInst)
+ if (handleError(
+ initPrivateVars(builder, moduleTranslation, privateVarsInfo),
+ *opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
if (failed(
@@ -2422,7 +2416,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
"failed to inline `cleanup` region of `omp.declare_reduction`");
if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
- llvmPrivateVars, privateDecls)))
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
builder.restoreIP(oldIP);
@@ -2490,30 +2485,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(opInst)))
return failure();
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(simdOp, privateDecls);
-
- for (mlir::Value privateVar : simdOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(simdOp);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return failure();
- if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
- privateDecls, mlirPrivateVars,
- llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
opInst)
.failed())
return failure();
@@ -2562,7 +2544,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
order, simdlen, safelen);
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
- llvmPrivateVars, privateDecls);
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls);
}
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
@@ -4186,37 +4169,21 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
// DistributeOp has only one region associated with it.
builder.restoreIP(codeGenIP);
+ PrivateVarsInfo privVarsInfo(distributeOp);
- // TODO This is a recurring pattern in almost all ops that need
- // privatization. Try to abstract it in a shared util/interface.
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*distributeOp)
- .getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(distributeOp, privateDecls);
-
- for (mlir::Value privateVar : distributeOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
-
- llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ llvm::Expected<llvm::BasicBlock *> afterAllocas =
+ allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return llvm::make_error<PreviouslyReportedError>();
- if (handleError(initPrivateVars(builder, moduleTranslation,
- privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privVarsInfo.mlirPrivateVars,
+ privVarsInfo.llvm...
[truncated]
|
@llvm/pr-subscribers-mlir-llvm Author: Kareem Ergawy (ergawy) ChangesMoves code needed to collect info about delayed privatizers into a shared util instread of repeating the same patter across all relevant constructs. Patch is 22.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131582.diff 1 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 17d0a7007729f..315c6b8ccc553 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -696,20 +696,42 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
-/// Populates `privatizations` with privatization declarations used for the
-/// given op.
-template <class OP>
-static void collectPrivatizationDecls(
- OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
- std::optional<ArrayAttr> attr = op.getPrivateSyms();
- if (!attr)
- return;
+/// A util to collect info needed to convert delayed privatizers from MLIR to
+/// LLVM.
+struct PrivateVarsInfo {
+ template <typename OP>
+ PrivateVarsInfo(OP op)
+ : privateBlockArgs(
+ cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
+ mlirPrivateVars.reserve(privateBlockArgs.size());
+ llvmPrivateVars.reserve(privateBlockArgs.size());
+ collectPrivatizationDecls<OP>(op, privateDecls);
- privatizations.reserve(privatizations.size() + attr->size());
- for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
- privatizations.push_back(findPrivatizer(op, symbolRef));
+ for (mlir::Value privateVar : op.getPrivateVars())
+ mlirPrivateVars.push_back(privateVar);
}
-}
+
+ MutableArrayRef<BlockArgument> privateBlockArgs;
+ SmallVector<mlir::Value> mlirPrivateVars;
+ SmallVector<llvm::Value *> llvmPrivateVars;
+ SmallVector<omp::PrivateClauseOp> privateDecls;
+
+private:
+ /// Populates `privatizations` with privatization declarations used for the
+ /// given op.
+ template <class OP>
+ static void collectPrivatizationDecls(
+ OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
+ std::optional<ArrayAttr> attr = op.getPrivateSyms();
+ if (!attr)
+ return;
+
+ privatizations.reserve(privatizations.size() + attr->size());
+ for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
+ privatizations.push_back(findPrivatizer(op, symbolRef));
+ }
+ }
+};
/// Populates `reductions` with reduction declarations used in the given op.
template <typename T>
@@ -1384,19 +1406,18 @@ static llvm::Expected<llvm::Value *> initPrivateVar(
static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- MutableArrayRef<BlockArgument> privateBlockArgs,
- MutableArrayRef<omp::PrivateClauseOp> privateDecls,
- MutableArrayRef<mlir::Value> mlirPrivateVars,
- llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+ PrivateVarsInfo &privateVarsInfo,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
- if (privateBlockArgs.empty())
+ if (privateVarsInfo.privateBlockArgs.empty())
return llvm::Error::success();
llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
- privateDecls, mlirPrivateVars, privateBlockArgs, llvmPrivateVars))) {
+ privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs,
+ privateVarsInfo.llvmPrivateVars))) {
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
@@ -1420,10 +1441,7 @@ initPrivateVars(llvm::IRBuilderBase &builder,
static llvm::Expected<llvm::BasicBlock *>
allocatePrivateVars(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- MutableArrayRef<BlockArgument> privateBlockArgs,
- MutableArrayRef<omp::PrivateClauseOp> privateDecls,
- MutableArrayRef<mlir::Value> mlirPrivateVars,
- llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+ PrivateVarsInfo &privateVarsInfo,
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// Allocate private vars
@@ -1449,8 +1467,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
->getDataLayout()
.getProgramAddressSpace();
- for (auto [privDecl, mlirPrivVar, blockArg] :
- llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
+ for (auto [privDecl, mlirPrivVar, blockArg] : llvm::zip_equal(
+ privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs)) {
llvm::Type *llvmAllocType =
moduleTranslation.convertType(privDecl.getType());
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
@@ -1460,7 +1479,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
builder.getPtrTy(defaultAS));
- llvmPrivateVars.push_back(llvmPrivateVar);
+ privateVarsInfo.llvmPrivateVars.push_back(llvmPrivateVar);
}
return afterAllocas;
@@ -1888,19 +1907,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*taskOp)))
return failure();
- // Collect delayed privatisation declarations
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(taskOp, privateDecls);
+ PrivateVarsInfo privateVarsInfo(taskOp);
TaskContextStructManager taskStructMgr{builder, moduleTranslation,
- privateDecls};
- for (mlir::Value privateVar : taskOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ privateVarsInfo.privateDecls};
// Allocate and copy private variables before creating the task. This avoids
// accessing invalid memory if (after this scope ends) the private variables
@@ -1959,7 +1968,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
taskStructMgr.createGEPsToPrivateVars();
for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
- llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs,
+ llvm::zip_equal(privateVarsInfo.privateDecls,
+ privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.privateBlockArgs,
taskStructMgr.getLLVMPrivateVarGEPs())) {
// To be handled inside the task.
if (!privDecl.readsFromMold())
@@ -1998,9 +2009,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
// firstprivate copy region
setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- taskStructMgr.getLLVMPrivateVarGEPs(),
- privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privateDecls)))
return llvm::failure();
// Set up for call to createTask()
@@ -2017,9 +2028,11 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
builder.restoreIP(codegenIP);
llvm::BasicBlock *privInitBlock = nullptr;
- llvmPrivateVars.resize(privateBlockArgs.size());
+ privateVarsInfo.llvmPrivateVars.resize(
+ privateVarsInfo.privateBlockArgs.size());
for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
- privateBlockArgs, privateDecls, mlirPrivateVars))) {
+ privateVarsInfo.privateBlockArgs, privateVarsInfo.privateDecls,
+ privateVarsInfo.mlirPrivateVars))) {
auto [blockArg, privDecl, mlirPrivVar] = zip;
// This is handled before the task executes
if (privDecl.readsFromMold())
@@ -2038,23 +2051,25 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (!privateVarOrError)
return privateVarOrError.takeError();
moduleTranslation.mapValue(blockArg, privateVarOrError.get());
- llvmPrivateVars[i] = privateVarOrError.get();
+ privateVarsInfo.llvmPrivateVars[i] = privateVarOrError.get();
}
taskStructMgr.createGEPsToPrivateVars();
for (auto [i, llvmPrivVar] :
llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
if (!llvmPrivVar) {
- assert(llvmPrivateVars[i] && "This is added in the loop above");
+ assert(privateVarsInfo.llvmPrivateVars[i] &&
+ "This is added in the loop above");
continue;
}
- llvmPrivateVars[i] = llvmPrivVar;
+ privateVarsInfo.llvmPrivateVars[i] = llvmPrivVar;
}
// Find and map the addresses of each variable within the task context
// structure
- for (auto [blockArg, llvmPrivateVar, privateDecl] :
- llvm::zip_equal(privateBlockArgs, llvmPrivateVars, privateDecls)) {
+ for (auto [blockArg, llvmPrivateVar, privateDecl] : llvm::zip_equal(
+ privateVarsInfo.privateBlockArgs, privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)) {
// This was handled above.
if (!privateDecl.readsFromMold())
continue;
@@ -2076,7 +2091,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
- llvmPrivateVars, privateDecls)))
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
// Free heap allocated task context structure at the end of the task.
@@ -2171,17 +2187,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
}
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(wsloopOp, privateDecls);
-
- for (mlir::Value privateVar : wsloopOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(wsloopOp);
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(wsloopOp, reductionDecls);
@@ -2192,8 +2198,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
wsloopOp.getNumReductionVars());
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return failure();
@@ -2210,15 +2215,14 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
deferredStores, isByRef)))
return failure();
- if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
- privateDecls, mlirPrivateVars,
- llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
opInst)
.failed())
return failure();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
return failure();
assert(afterAllocas.get()->getSinglePredecessor());
@@ -2271,7 +2275,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
- llvmPrivateVars, privateDecls);
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls);
}
/// Converts the OpenMP parallel operation to LLVM IR.
@@ -2286,17 +2291,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*opInst)))
return failure();
- // Collect delayed privatization declarations
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(opInst, privateDecls);
- for (mlir::Value privateVar : opInst.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(opInst);
// Collect reduction declarations
SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2308,8 +2303,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
auto bodyGenCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> llvm::Error {
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, *opInst).failed())
return llvm::make_error<PreviouslyReportedError>();
@@ -2332,15 +2326,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
assert(afterAllocas.get()->getSinglePredecessor());
builder.restoreIP(codeGenIP);
- if (handleError(initPrivateVars(builder, moduleTranslation,
- privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars),
- *opInst)
+ if (handleError(
+ initPrivateVars(builder, moduleTranslation, privateVarsInfo),
+ *opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+ privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
if (failed(
@@ -2422,7 +2416,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
"failed to inline `cleanup` region of `omp.declare_reduction`");
if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
- llvmPrivateVars, privateDecls)))
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
builder.restoreIP(oldIP);
@@ -2490,30 +2485,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(opInst)))
return failure();
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(simdOp, privateDecls);
-
- for (mlir::Value privateVar : simdOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
+ PrivateVarsInfo privateVarsInfo(simdOp);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return failure();
- if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
- privateDecls, mlirPrivateVars,
- llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
opInst)
.failed())
return failure();
@@ -2562,7 +2544,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
order, simdlen, safelen);
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
- llvmPrivateVars, privateDecls);
+ privateVarsInfo.llvmPrivateVars,
+ privateVarsInfo.privateDecls);
}
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
@@ -4186,37 +4169,21 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
// DistributeOp has only one region associated with it.
builder.restoreIP(codeGenIP);
+ PrivateVarsInfo privVarsInfo(distributeOp);
- // TODO This is a recurring pattern in almost all ops that need
- // privatization. Try to abstract it in a shared util/interface.
- MutableArrayRef<BlockArgument> privateBlockArgs =
- cast<omp::BlockArgOpenMPOpInterface>(*distributeOp)
- .getPrivateBlockArgs();
- SmallVector<mlir::Value> mlirPrivateVars;
- SmallVector<llvm::Value *> llvmPrivateVars;
- SmallVector<omp::PrivateClauseOp> privateDecls;
- mlirPrivateVars.reserve(privateBlockArgs.size());
- llvmPrivateVars.reserve(privateBlockArgs.size());
- collectPrivatizationDecls(distributeOp, privateDecls);
-
- for (mlir::Value privateVar : distributeOp.getPrivateVars())
- mlirPrivateVars.push_back(privateVar);
-
- llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
- builder, moduleTranslation, privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars, allocaIP);
+ llvm::Expected<llvm::BasicBlock *> afterAllocas =
+ allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
return llvm::make_error<PreviouslyReportedError>();
- if (handleError(initPrivateVars(builder, moduleTranslation,
- privateBlockArgs, privateDecls,
- mlirPrivateVars, llvmPrivateVars),
+ if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();
- if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
- llvmPrivateVars, privateDecls)))
+ if (failed(copyFirstPrivateVars(
+ builder, moduleTranslation, privVarsInfo.mlirPrivateVars,
+ privVarsInfo.llvm...
[truncated]
|
f75f164
to
17f14be
Compare
1b934ba
to
19e8b68
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice cleanup. I think this is a NFC? Could you add that to the commit title
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you for this change.
19e8b68
to
03e97e1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some minor nits, but this LGTM, thank you!
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
…to a shared util Moves code needed to collect info about delayed privatizers into a shared util instread of repeating the same patter across all relevant constructs.
03e97e1
to
b93f7db
Compare
…tizer info into a shared util (llvm#131582)" This reverts commit b7eb01b.
…fo into a shared util (llvm#131582) Moves code needed to collect info about delayed privatizers into a shared util instread of repeating the same patter across all relevant constructs.
…fo into a shared util (llvm#131582) (llvm#1290)
Moves code needed to collect info about delayed privatizers into a shared util instread of repeating the same patter across all relevant constructs.