Skip to content

Commit 7d4e732

Browse files
committed
[flang][OpenMP] Implicitly map allocatable record fields
This is a starting PR to implicitly map allocatable record fields. This PR contains the following changes: 1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that these utils work on the `mlir::Value` level rather than the `semantics::Symbol` level. This takes one step towards to enabling MLIR passes to more easily do some lowering themselves (e.g. creating `omp.map.bounds` ops for implicitely caputured data like this PR does). 2. Adds support for implicitely capturing and mapping allocatable fields in record types. There is quite some distant to still cover to have full support for this. I added a number of todos to guide further development.
1 parent 56ddbef commit 7d4e732

File tree

11 files changed

+329
-36
lines changed

11 files changed

+329
-36
lines changed

flang/lib/Lower/DirectivesCommon.h renamed to flang/include/flang/Lower/DirectivesCommon.h

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -609,32 +609,22 @@ void createEmptyRegionBlocks(
609609
}
610610
}
611611

612-
inline AddrAndBoundsInfo
613-
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
614-
fir::FirOpBuilder &builder,
615-
Fortran::lower::SymbolRef sym, mlir::Location loc) {
616-
mlir::Value symAddr = converter.getSymbolAddress(sym);
612+
inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
613+
mlir::Value symAddr,
614+
bool isOptional,
615+
mlir::Location loc) {
617616
mlir::Value rawInput = symAddr;
618617
if (auto declareOp =
619618
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
620619
symAddr = declareOp.getResults()[0];
621620
rawInput = declareOp.getResults()[1];
622621
}
623622

624-
// TODO: Might need revisiting to handle for non-shared clauses
625-
if (!symAddr) {
626-
if (const auto *details =
627-
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
628-
symAddr = converter.getSymbolAddress(details->symbol());
629-
rawInput = symAddr;
630-
}
631-
}
632-
633623
if (!symAddr)
634624
llvm::report_fatal_error("could not retrieve symbol address");
635625

636626
mlir::Value isPresent;
637-
if (Fortran::semantics::IsOptional(sym))
627+
if (isOptional)
638628
isPresent =
639629
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
640630

@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
648638
// all address/dimension retrievals. For Fortran optional though, leave
649639
// the load generation for later so it can be done in the appropriate
650640
// if branches.
651-
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
652-
!Fortran::semantics::IsOptional(sym)) {
641+
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
653642
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
654643
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
655644
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
659648
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
660649
}
661650

651+
inline AddrAndBoundsInfo
652+
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
653+
fir::FirOpBuilder &builder,
654+
Fortran::lower::SymbolRef sym, mlir::Location loc) {
655+
return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
656+
Fortran::semantics::IsOptional(sym), loc);
657+
}
658+
662659
template <typename BoundsOp, typename BoundsType>
663660
llvm::SmallVector<mlir::Value>
664661
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
12241221

12251222
return info;
12261223
}
1224+
1225+
template <typename BoundsOp, typename BoundsType>
1226+
llvm::SmallVector<mlir::Value>
1227+
genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
1228+
fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
1229+
mlir::Location loc) {
1230+
llvm::SmallVector<mlir::Value> bounds;
1231+
1232+
mlir::Value baseOp = info.rawInput;
1233+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1234+
bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
1235+
dataExv, info);
1236+
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1237+
bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
1238+
builder, loc, dataExv, dataExvIsAssumedSize);
1239+
}
1240+
1241+
return bounds;
1242+
}
12271243
} // namespace lower
12281244
} // namespace Fortran
12291245

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "flang/Lower/Bridge.h"
14-
#include "DirectivesCommon.h"
14+
1515
#include "flang/Common/Version.h"
1616
#include "flang/Lower/Allocatable.h"
1717
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
2222
#include "flang/Lower/ConvertType.h"
2323
#include "flang/Lower/ConvertVariable.h"
2424
#include "flang/Lower/Cuda.h"
25+
#include "flang/Lower/DirectivesCommon.h"
2526
#include "flang/Lower/HostAssociations.h"
2627
#include "flang/Lower/IO.h"
2728
#include "flang/Lower/IterationSpace.h"

flang/lib/Lower/OpenACC.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "flang/Lower/OpenACC.h"
14-
#include "DirectivesCommon.h"
14+
1515
#include "flang/Common/idioms.h"
1616
#include "flang/Lower/Bridge.h"
1717
#include "flang/Lower/ConvertType.h"
18+
#include "flang/Lower/DirectivesCommon.h"
1819
#include "flang/Lower/Mangler.h"
1920
#include "flang/Lower/PFTBuilder.h"
2021
#include "flang/Lower/StatementContext.h"

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
1414

1515
#include "Clauses.h"
16-
#include "DirectivesCommon.h"
1716
#include "ReductionProcessor.h"
1817
#include "Utils.h"
1918
#include "flang/Lower/AbstractConverter.h"
2019
#include "flang/Lower/Bridge.h"
20+
#include "flang/Lower/DirectivesCommon.h"
2121
#include "flang/Optimizer/Builder/Todo.h"
2222
#include "flang/Parser/dump-parse-tree.h"
2323
#include "flang/Parser/parse-tree.h"

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
#include "Clauses.h"
1717
#include "DataSharingProcessor.h"
1818
#include "Decomposer.h"
19-
#include "DirectivesCommon.h"
2019
#include "ReductionProcessor.h"
2120
#include "Utils.h"
2221
#include "flang/Common/OpenMP-utils.h"
2322
#include "flang/Common/idioms.h"
2423
#include "flang/Lower/Bridge.h"
2524
#include "flang/Lower/ConvertExpr.h"
2625
#include "flang/Lower/ConvertVariable.h"
26+
#include "flang/Lower/DirectivesCommon.h"
2727
#include "flang/Lower/StatementContext.h"
2828
#include "flang/Lower/SymbolMap.h"
2929
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1731,32 +1731,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
17311731
if (const auto *details =
17321732
sym.template detailsIf<semantics::HostAssocDetails>())
17331733
converter.copySymbolBinding(details->symbol(), sym);
1734-
llvm::SmallVector<mlir::Value> bounds;
17351734
std::stringstream name;
17361735
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
17371736
name << sym.name().ToString();
17381737

17391738
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
17401739
converter, firOpBuilder, sym, converter.getCurrentLocation());
1741-
mlir::Value baseOp = info.rawInput;
1742-
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1743-
bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
1744-
mlir::omp::MapBoundsType>(
1745-
firOpBuilder, converter.getCurrentLocation(), dataExv, info);
1746-
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1747-
bool dataExvIsAssumedSize =
1748-
semantics::IsAssumedSizeArray(sym.GetUltimate());
1749-
bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
1750-
mlir::omp::MapBoundsType>(
1751-
firOpBuilder, converter.getCurrentLocation(), dataExv,
1752-
dataExvIsAssumedSize);
1753-
}
1740+
llvm::SmallVector<mlir::Value> bounds =
1741+
lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
1742+
mlir::omp::MapBoundsType>(
1743+
firOpBuilder, info, dataExv,
1744+
semantics::IsAssumedSizeArray(sym.GetUltimate()),
1745+
converter.getCurrentLocation());
17541746

17551747
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
17561748
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
17571749
mlir::omp::VariableCaptureKind captureKind =
17581750
mlir::omp::VariableCaptureKind::ByRef;
17591751

1752+
mlir::Value baseOp = info.rawInput;
17601753
mlir::Type eleType = baseOp.getType();
17611754
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
17621755
eleType = refType.getElementType();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
#include "Utils.h"
1414

1515
#include "Clauses.h"
16-
#include <DirectivesCommon.h>
1716

1817
#include <flang/Lower/AbstractConverter.h>
1918
#include <flang/Lower/ConvertType.h>
19+
#include <flang/Lower/DirectivesCommon.h>
2020
#include <flang/Lower/PFTBuilder.h>
2121
#include <flang/Optimizer/Builder/FIRBuilder.h>
2222
#include <flang/Optimizer/Builder/Todo.h>

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
1212
FIRDialect
1313
HLFIROpsIncGen
1414
FlangOpenMPPassesIncGen
15+
${dialect_libs}
1516

1617
LINK_LIBS
1718
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
2728
MLIRIR
2829
MLIRPass
2930
MLIRTransformUtils
31+
${dialect_libs}
3032
)

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@
2424
/// indirectly via a parent object.
2525
//===----------------------------------------------------------------------===//
2626

27+
#include "flang/Lower/DirectivesCommon.h"
2728
#include "flang/Optimizer/Builder/FIRBuilder.h"
29+
#include "flang/Optimizer/Builder/HLFIRTools.h"
2830
#include "flang/Optimizer/Dialect/FIRType.h"
2931
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
32+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3033
#include "flang/Optimizer/OpenMP/Passes.h"
34+
#include "mlir/Analysis/SliceAnalysis.h"
3135
#include "mlir/Dialect/Func/IR/FuncOps.h"
3236
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3337
#include "mlir/IR/BuiltinDialect.h"
@@ -486,6 +490,160 @@ class MapInfoFinalizationPass
486490
// iterations from previous function scopes.
487491
localBoxAllocas.clear();
488492

493+
// First, walk `omp.map.info` ops to see if any record members should be
494+
// implicitly mapped.
495+
func->walk([&](mlir::omp::MapInfoOp op) {
496+
mlir::Type underlyingType =
497+
fir::unwrapRefType(op.getVarPtr().getType());
498+
499+
// TODO Test with and support more complicated cases; like arrays for
500+
// records, for example.
501+
if (!fir::isRecordWithAllocatableMember(underlyingType))
502+
return mlir::WalkResult::advance();
503+
504+
// TODO For now, only consider `omp.target` ops. Other ops that support
505+
// `map` clauses will follow later.
506+
mlir::omp::TargetOp target =
507+
mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
508+
getFirstTargetUser(op));
509+
510+
if (!target)
511+
return mlir::WalkResult::advance();
512+
513+
auto mapClauseOwner =
514+
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
515+
516+
int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
517+
assert(mapVarIdx >= 0 &&
518+
mapVarIdx <
519+
static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
520+
521+
auto argIface =
522+
llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
523+
// TODO How should `map` block argument that correspond to: `private`,
524+
// `use_device_addr`, `use_device_ptr`, be handled?
525+
mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
526+
llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
527+
mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
528+
529+
mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
530+
// TODO Support coordinate_of ops.
531+
//
532+
// TODO Support call ops by recursively examining the forward slice of
533+
// the corresponding parameter to the field in the called function.
534+
return !mlir::isa<hlfir::DesignateOp>(sliceOp);
535+
});
536+
537+
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
538+
llvm::SmallVector<mlir::Value> newMapOpsForFields;
539+
llvm::SmallVector<int64_t> fieldIndicies;
540+
541+
for (auto fieldMemTyPair : recordType.getTypeList()) {
542+
auto &field = fieldMemTyPair.first;
543+
auto memTy = fieldMemTyPair.second;
544+
545+
bool shouldMapField =
546+
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
547+
if (!fir::isAllocatableType(memTy))
548+
return false;
549+
550+
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
551+
if (!designateOp)
552+
return false;
553+
554+
return designateOp.getComponent() &&
555+
designateOp.getComponent()->strref() == field;
556+
}) != mapVarForwardSlice.end();
557+
558+
// TODO Handle recursive record types. Adapting
559+
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
560+
// entities might be helpful here.
561+
562+
if (!shouldMapField)
563+
continue;
564+
565+
int64_t fieldIdx = recordType.getFieldIndex(field);
566+
bool alreadyMapped = [&]() {
567+
if (op.getMembersIndexAttr())
568+
for (auto indexList : op.getMembersIndexAttr()) {
569+
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
570+
if (indexListAttr.size() == 1 &&
571+
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
572+
fieldIdx)
573+
return true;
574+
}
575+
576+
return false;
577+
}();
578+
579+
if (alreadyMapped)
580+
continue;
581+
582+
builder.setInsertionPoint(op);
583+
mlir::Value fieldIdxVal = builder.createIntegerConstant(
584+
op.getLoc(), mlir::IndexType::get(builder.getContext()),
585+
fieldIdx);
586+
auto fieldCoord = builder.create<fir::CoordinateOp>(
587+
op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
588+
fieldIdxVal);
589+
Fortran::lower::AddrAndBoundsInfo info =
590+
Fortran::lower::getDataOperandBaseAddr(
591+
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
592+
llvm::SmallVector<mlir::Value> bounds =
593+
Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
594+
mlir::omp::MapBoundsType>(
595+
builder, info,
596+
hlfir::translateToExtendedValue(op.getLoc(), builder,
597+
hlfir::Entity{fieldCoord})
598+
.first,
599+
/*dataExvIsAssumedSize=*/false, op.getLoc());
600+
601+
mlir::omp::MapInfoOp fieldMapOp =
602+
builder.create<mlir::omp::MapInfoOp>(
603+
op.getLoc(), fieldCoord.getResult().getType(),
604+
fieldCoord.getResult(),
605+
mlir::TypeAttr::get(
606+
fir::unwrapRefType(fieldCoord.getResult().getType())),
607+
/*varPtrPtr=*/mlir::Value{},
608+
/*members=*/mlir::ValueRange{},
609+
/*members_index=*/mlir::ArrayAttr{},
610+
/*bounds=*/bounds, op.getMapTypeAttr(),
611+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
612+
mlir::omp::VariableCaptureKind::ByRef),
613+
builder.getStringAttr(op.getNameAttr().strref() + "." +
614+
field + ".implicit_map"),
615+
/*partial_map=*/builder.getBoolAttr(false));
616+
newMapOpsForFields.emplace_back(fieldMapOp);
617+
fieldIndicies.emplace_back(fieldIdx);
618+
}
619+
620+
if (newMapOpsForFields.empty())
621+
return mlir::WalkResult::advance();
622+
623+
op.getMembersMutable().append(newMapOpsForFields);
624+
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
625+
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
626+
627+
if (oldMembersIdxAttr)
628+
for (mlir::Attribute indexList : oldMembersIdxAttr) {
629+
llvm::SmallVector<int64_t> listVec;
630+
631+
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
632+
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
633+
634+
newMemberIndices.emplace_back(std::move(listVec));
635+
}
636+
637+
for (int64_t newFieldIdx : fieldIndicies)
638+
newMemberIndices.emplace_back(
639+
llvm::SmallVector<int64_t>(1, newFieldIdx));
640+
641+
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
642+
op.setPartialMap(true);
643+
644+
return mlir::WalkResult::advance();
645+
});
646+
489647
func->walk([&](mlir::omp::MapInfoOp op) {
490648
// TODO: Currently only supports a single user for the MapInfoOp. This
491649
// is fine for the moment, as the Fortran frontend will generate a

0 commit comments

Comments
 (0)