Skip to content

[flang][OpenMP] Implicitly map allocatable record fields #117867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -609,32 +609,22 @@ void createEmptyRegionBlocks(
}
}

inline AddrAndBoundsInfo
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder,
Fortran::lower::SymbolRef sym, mlir::Location loc) {
mlir::Value symAddr = converter.getSymbolAddress(sym);
inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
mlir::Value symAddr,
bool isOptional,
mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
symAddr = declareOp.getResults()[0];
rawInput = declareOp.getResults()[1];
}

// TODO: Might need revisiting to handle for non-shared clauses
if (!symAddr) {
if (const auto *details =
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
symAddr = converter.getSymbolAddress(details->symbol());
rawInput = symAddr;
}
}

if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");

mlir::Value isPresent;
if (Fortran::semantics::IsOptional(sym))
if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);

Expand All @@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
!Fortran::semantics::IsOptional(sym)) {
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
Expand All @@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}

inline AddrAndBoundsInfo
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder,
Fortran::lower::SymbolRef sym, mlir::Location loc) {
return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
Fortran::semantics::IsOptional(sym), loc);
}

template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
Expand Down Expand Up @@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(

return info;
}

template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
mlir::Location loc) {
llvm::SmallVector<mlir::Value> bounds;

mlir::Value baseOp = info.rawInput;
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
dataExv, info);
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
builder, loc, dataExv, dataExvIsAssumedSize);
}

return bounds;
}
} // namespace lower
} // namespace Fortran

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//

#include "flang/Lower/Bridge.h"
#include "DirectivesCommon.h"

#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
Expand All @@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//

#include "flang/Lower/OpenACC.h"
#include "DirectivesCommon.h"

#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H

#include "Clauses.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
Expand Down
23 changes: 8 additions & 15 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
Expand Down Expand Up @@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();

lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
mlir::Value baseOp = info.rawInput;
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
firOpBuilder, converter.getCurrentLocation(), dataExv, info);
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
bool dataExvIsAssumedSize =
semantics::IsAssumedSizeArray(sym.GetUltimate());
bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
firOpBuilder, converter.getCurrentLocation(), dataExv,
dataExvIsAssumedSize);
}
llvm::SmallVector<mlir::Value> bounds =
lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
firOpBuilder, info, dataExv,
semantics::IsAssumedSizeArray(sym.GetUltimate()),
converter.getCurrentLocation());

llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;

mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
#include "Utils.h"

#include "Clauses.h"
#include <DirectivesCommon.h>

#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/OpenMP/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
${dialect_libs}

LINK_LIBS
FIRAnalysis
Expand All @@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
${dialect_libs}
)
158 changes: 158 additions & 0 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//

#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
Expand Down Expand Up @@ -486,6 +490,160 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();

// First, walk `omp.map.info` ops to see if any record members should be
// implicitly mapped.
func->walk([&](mlir::omp::MapInfoOp op) {
mlir::Type underlyingType =
fir::unwrapRefType(op.getVarPtr().getType());

// TODO Test with and support more complicated cases; like arrays for
// records, for example.
if (!fir::isRecordWithAllocatableMember(underlyingType))
return mlir::WalkResult::advance();

// TODO For now, only consider `omp.target` ops. Other ops that support
// `map` clauses will follow later.
mlir::omp::TargetOp target =
mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will have to work for anything that has an explicit map clause as well. From my understanding we have to map these components if someone was to write map(tofrom: some_dtype_with_allocas). Worth double checking this with @mjklemm. Hopefully isn't too hard to extend, but if it is, I'd be happy with a first pass that works for just TargetOp, which is arguably the hardest to do it for in any case :-)

Copy link
Member Author

@ergawy ergawy Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a todo.

getFirstTargetUser(op));

if (!target)
return mlir::WalkResult::advance();

auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);

int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
assert(mapVarIdx >= 0 &&
mapVarIdx <
static_cast<int64_t>(mapClauseOwner.getMapVars().size()));

auto argIface =
llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
// TODO How should `map` block argument that correspond to: `private`,
// `use_device_addr`, `use_device_ptr`, be handled?
mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to be a tad wary of how this works with use_device_addr/ptr, as they're both map info holders, and soon with Pranav's and your work I believe private may also be, and I don't think they have the same implicit connotations regular map does. We may also eventually need to be careful of this with declare mappers as well, as they're explicit user defined mappings that shouldn't have the implicit behavior but that's a problem for another day/PR when that work is further along :-)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a todo.

llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);

mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
// TODO Support coordinate_of ops.
//
// TODO Support call ops by recursively examining the forward slice of
// the corresponding parameter to the field in the called function.
return !mlir::isa<hlfir::DesignateOp>(sliceOp);
});

auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
llvm::SmallVector<int64_t> fieldIndicies;

for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;

bool shouldMapField =
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
if (!fir::isAllocatableType(memTy))
return false;

auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
if (!designateOp)
return false;

return designateOp.getComponent() &&
designateOp.getComponent()->strref() == field;
}) != mapVarForwardSlice.end();

// TODO Handle recursive record types. Adapting
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
// entities might be helpful here.

if (!shouldMapField)
continue;

int64_t fieldIdx = recordType.getFieldIndex(field);
bool alreadyMapped = [&]() {
if (op.getMembersIndexAttr())
for (auto indexList : op.getMembersIndexAttr()) {
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
if (indexListAttr.size() == 1 &&
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
fieldIdx)
return true;
}

return false;
}();

if (alreadyMapped)
continue;

builder.setInsertionPoint(op);
mlir::Value fieldIdxVal = builder.createIntegerConstant(
op.getLoc(), mlir::IndexType::get(builder.getContext()),
fieldIdx);
auto fieldCoord = builder.create<fir::CoordinateOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the whole recursive generation / gather of indices, you may be able to borrow / tweak (improve on ;-)) the explicit map code that does something similar in the function: "createParentSymAndGenIntermediateMaps"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer. Also added a todo for a later PR.

op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
fieldIdxVal);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::getDataOperandBaseAddr(
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
llvm::SmallVector<mlir::Value> bounds =
Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
builder, info,
hlfir::translateToExtendedValue(op.getLoc(), builder,
hlfir::Entity{fieldCoord})
.first,
/*dataExvIsAssumedSize=*/false, op.getLoc());

mlir::omp::MapInfoOp fieldMapOp =
builder.create<mlir::omp::MapInfoOp>(
op.getLoc(), fieldCoord.getResult().getType(),
fieldCoord.getResult(),
mlir::TypeAttr::get(
fir::unwrapRefType(fieldCoord.getResult().getType())),
/*varPtrPtr=*/mlir::Value{},
/*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{},
/*bounds=*/bounds, op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably take the capture kind from the original operation as well, to keep it inline with the map type attribute, may save some possible weirdness :-)

But taking the map type from the parent does raise the interesting question of if the implicit mapping should take on the specified map type behavior of it's parent or if it should default to the regular implicit map behaviour of it's type. Not too sure on that one, the spec or one of our local specification gurus may have some insight into that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe @mjklemm can provide some guidance here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that's the right behavior. However, we need to bear in mind that a user might have used declare mapper to define a mapper for a derived type. In that case, we might need to honor the user-defined mapper instead.

mlir::omp::VariableCaptureKind::ByRef),
builder.getStringAttr(op.getNameAttr().strref() + "." +
field + ".implicit_map"),
/*partial_map=*/builder.getBoolAttr(false));
newMapOpsForFields.emplace_back(fieldMapOp);
fieldIndicies.emplace_back(fieldIdx);
}

if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();

op.getMembersMutable().append(newMapOpsForFields);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();

if (oldMembersIdxAttr)
for (mlir::Attribute indexList : oldMembersIdxAttr) {
llvm::SmallVector<int64_t> listVec;

for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());

newMemberIndices.emplace_back(std::move(listVec));
}

for (int64_t newFieldIdx : fieldIndicies)
newMemberIndices.emplace_back(
llvm::SmallVector<int64_t>(1, newFieldIdx));

op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);

return mlir::WalkResult::advance();
});

func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
Expand Down
Loading
Loading