Skip to content

[Flang][OpenMP][MLIR] Initial derived type member map support #82853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/docs/OpenMP-descriptor-management.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after
the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran,
`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the
`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the
`omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple
mappings, with one extra per pointer member in the descriptor that is supported on top of the original
descriptor map operation. These pointer members are linked to the parent descriptor by adding them to
Expand All @@ -53,7 +53,7 @@ owning operation's (`omp.TargetOp`, `omp.TargetDataOp` etc.) map operand list an
operation is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and
simplify lowering.

An example transformation by the `OMPDescriptorMapInfoGenPass`:
An example transformation by the `OMPMapInfoFinalizationPass`:

```

Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
std::unique_ptr<mlir::Pass>
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);

std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createOMPMarkDeclareTargetPass();
Expand Down
6 changes: 3 additions & 3 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
let dependentDialects = [ "fir::FIROpsDialect" ];
}

def OMPDescriptorMapInfoGenPass
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
def OMPMapInfoFinalizationPass
: Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
let summary = "expands OpenMP MapInfo operations containing descriptors";
let description = [{
Expands MapInfo operations containing descriptor types into multiple
MapInfo's for each pointer element in the descriptor that requires
explicit individual mapping by the OpenMP runtime.
}];
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
let constructor = "::fir::createOMPMapInfoFinalizationPass()";
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}

Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ inline void createHLFIRToFIRPassPipeline(
/// rather than the host device.
inline void createOpenMPFIRPassPipeline(
mlir::PassManager &pm, bool isTargetDevice) {
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
pm.addPass(fir::createOMPMapInfoFinalizationPass());
pm.addPass(fir::createOMPMarkDeclareTargetPass());
if (isTargetDevice)
pm.addPass(fir::createOMPFunctionFilteringPass());
Expand Down
67 changes: 31 additions & 36 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,38 +814,24 @@ bool ClauseProcessor::processLink(
});
}

/// Creates an `mlir::omp::MapInfoOp` for `baseAddr` at `loc`.
///
/// When `baseAddr` has a `fir::BaseBoxType` type, a `fir::BoxAddrOp` is
/// created first; the map operation is then built on its result, and that
/// result's type replaces the caller-supplied `retTy`.
///
/// \param builder        builder used to create every operation and attribute.
/// \param loc            location attached to the created operations.
/// \param baseAddr       value to map; replaced by its `fir::BoxAddrOp` result
///                       when it is a box.
/// \param varPtrPtr      forwarded unchanged to the map operation.
/// \param name           stored on the map operation as a string attribute.
/// \param bounds         forwarded unchanged to the map operation.
/// \param members        forwarded unchanged to the map operation.
/// \param mapType        stored as a 64-bit unsigned integer attribute.
/// \param mapCaptureType stored as a `VariableCaptureKindAttr`.
/// \param retTy          result type of the map operation; it is cast to
///                       `mlir::omp::PointerLikeType` to derive the variable
///                       type attribute from its element type.
/// \param isVal          not read by this implementation.
/// \return the newly created map operation.
mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
                llvm::ArrayRef<mlir::Value> bounds,
                llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
                mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
                bool isVal) {
  // Only the kind of type matters here, so test with isa instead of binding
  // a dyn_cast result that would never be read.
  if (mlir::isa<fir::BaseBoxType>(baseAddr.getType())) {
    baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
    retTy = baseAddr.getType();
  }

  // The variable type is the element type of the (possibly updated) result
  // type, so it must be computed after the box handling above.
  mlir::TypeAttr varType = mlir::TypeAttr::get(
      llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());

  return builder.create<mlir::omp::MapInfoOp>(
      loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
      builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
      builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
      builder.getStringAttr(name));
}

bool ClauseProcessor::processMap(
mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
// We always require tracking of symbols, even if the caller does not,
// so we create an optionally used local set of symbols when the mapSyms
// argument is not present.
llvm::SmallVector<const Fortran::semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<const Fortran::semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
using Map = omp::clause::Map;
Expand Down Expand Up @@ -910,24 +896,33 @@ bool ClauseProcessor::processMap(
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
asFortran.str(), bounds, {},
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());

result.mapVars.push_back(mapOp);

if (mapSyms)
mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
if (object.id()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
semaCtx);
} else {
result.mapVars.push_back(mapOp);
ptrMapSyms->push_back(object.id());
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
}
}
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
*ptrMapSyms, mapSymTypes, mapSymLocs);

return clauseFound;
}

bool ClauseProcessor::processReduction(
Expand Down
28 changes: 23 additions & 5 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this patch, but perhaps we should think about refactoring processMap and processMotionClauses to avoid code duplication for those parts that are shared between them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree and have been thinking about it: there aren't a whole lot of dissimilarities, most changes to processMapClauses need to be replicated in processMotionClauses, and merging them would make testing a lot simpler! But it would also be up to @ergawy; I am unsure whether he'd like to keep them distinct, or whether he has some insights into motion clauses that would make merging the two functions less than ideal!

But I can make a follow-up PR to merge them after the stack has landed, if we are all happy doing so.

Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
return findRepeatableClause<T>(
std::map<const Fortran::semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;

bool clauseFound = findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand All @@ -203,6 +208,7 @@ bool ClauseProcessor::processMotionClauses(
for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
Expand All @@ -218,17 +224,29 @@ bool ClauseProcessor::processMotionClauses(
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
asFortran.str(), bounds, {},
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());

result.mapVars.push_back(mapOp);
if (object.id()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
semaCtx);
} else {
result.mapVars.push_back(mapOp);
mapSymbols.push_back(object.id());
}
}
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymbols,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
return clauseFound;
}

template <typename... Ts>
Expand Down
11 changes: 7 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,8 +939,10 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
std::stringstream name;
firOpBuilder.setInsertionPoint(targetOp);
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
bounds, llvm::SmallVector<mlir::Value>{},
firOpBuilder, copyVal.getLoc(), copyVal,
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
/*members=*/llvm::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
Expand Down Expand Up @@ -1637,8 +1639,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
}

mlir::Value mapOp = createMapInfoOp(
firOpBuilder, baseOp.getLoc(), baseOp, mlir::Value{}, name.str(),
bounds, {},
firOpBuilder, baseOp.getLoc(), baseOp, /*varPtrPtr=*/mlir::Value{},
name.str(), bounds, /*members=*/{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapFlag),
Expand Down
Loading
Loading