Skip to content

Commit a706f5b

Browse files
committed
[Flang][OpenMP] Fix allocating arrays with size intrinisic
Attempt to address the following example from causing an assert or ICE: subroutine test(a) implicit none integer :: i real(kind=real64), dimension(:) :: a real(kind=real64), dimension(size(a, 1)) :: b !$omp target map(tofrom: b) do i = 1, 10 b(i) = i end do !$omp end target end subroutine Where we utilise a Fortran intrinsic (size) to calculate the size of allocatable arrays and then map it to device. Borrowing some of Kareem Ergawy's current work to disentangle bounds generation from the semantic/PFT information. Co-author: Kareem Ergawy : [email protected]
1 parent 10f315d commit a706f5b

File tree

9 files changed

+177
-44
lines changed

9 files changed

+177
-44
lines changed

flang/lib/Lower/DirectivesCommon.h

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -609,32 +609,22 @@ void createEmptyRegionBlocks(
609609
}
610610
}
611611

612-
inline AddrAndBoundsInfo
613-
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
614-
fir::FirOpBuilder &builder,
615-
Fortran::lower::SymbolRef sym, mlir::Location loc) {
616-
mlir::Value symAddr = converter.getSymbolAddress(sym);
612+
inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
613+
mlir::Value symAddr,
614+
bool isOptional,
615+
mlir::Location loc) {
617616
mlir::Value rawInput = symAddr;
618617
if (auto declareOp =
619618
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
620619
symAddr = declareOp.getResults()[0];
621620
rawInput = declareOp.getResults()[1];
622621
}
623622

624-
// TODO: Might need revisiting to handle for non-shared clauses
625-
if (!symAddr) {
626-
if (const auto *details =
627-
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
628-
symAddr = converter.getSymbolAddress(details->symbol());
629-
rawInput = symAddr;
630-
}
631-
}
632-
633623
if (!symAddr)
634624
llvm::report_fatal_error("could not retrieve symbol address");
635625

636626
mlir::Value isPresent;
637-
if (Fortran::semantics::IsOptional(sym))
627+
if (isOptional)
638628
isPresent =
639629
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
640630

@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
648638
// all address/dimension retrievals. For Fortran optional though, leave
649639
// the load generation for later so it can be done in the appropriate
650640
// if branches.
651-
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
652-
!Fortran::semantics::IsOptional(sym)) {
641+
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
653642
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
654643
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
655644
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
659648
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
660649
}
661650

651+
inline AddrAndBoundsInfo
652+
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
653+
fir::FirOpBuilder &builder,
654+
Fortran::lower::SymbolRef sym, mlir::Location loc) {
655+
return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
656+
Fortran::semantics::IsOptional(sym), loc);
657+
}
658+
662659
template <typename BoundsOp, typename BoundsType>
663660
llvm::SmallVector<mlir::Value>
664661
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,26 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
12241221

12251222
return info;
12261223
}
1224+
1225+
template <typename BoundsOp, typename BoundsType>
1226+
llvm::SmallVector<mlir::Value>
1227+
genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
1228+
fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
1229+
mlir::Location loc) {
1230+
llvm::SmallVector<mlir::Value> bounds;
1231+
1232+
mlir::Value baseOp = info.rawInput;
1233+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1234+
bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
1235+
dataExv, info);
1236+
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1237+
bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
1238+
builder, loc, dataExv, dataExvIsAssumedSize);
1239+
}
1240+
1241+
return bounds;
1242+
}
1243+
12271244
} // namespace lower
12281245
} // namespace Fortran
12291246

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -922,32 +922,70 @@ static void genBodyOfTargetOp(
922922
while (!valuesDefinedAbove.empty()) {
923923
for (mlir::Value val : valuesDefinedAbove) {
924924
mlir::Operation *valOp = val.getDefiningOp();
925-
if (mlir::isMemoryEffectFree(valOp)) {
925+
assert(valOp != nullptr);
926+
927+
// NOTE: We skip BoxDimsOp's as the lesser of two evils is to map the
928+
// indices separately, as the alternative is to eventually map the Box,
929+
// which comes with a fairly large overhead comparatively. We could be
930+
// more robust about this and check using a BackWardsSlice to see if we
931+
// run the risk of mapping a box.
932+
if (mlir::isMemoryEffectFree(valOp) &&
933+
!mlir::isa<fir::BoxDimsOp>(valOp)) {
926934
mlir::Operation *clonedOp = valOp->clone();
927935
entryBlock->push_front(clonedOp);
928-
val.replaceUsesWithIf(clonedOp->getResult(0),
929-
[entryBlock](mlir::OpOperand &use) {
930-
return use.getOwner()->getBlock() == entryBlock;
931-
});
936+
937+
auto replace = [entryBlock](mlir::OpOperand &use) {
938+
return use.getOwner()->getBlock() == entryBlock;
939+
};
940+
941+
valOp->getResults().replaceUsesWithIf(clonedOp->getResults(), replace);
942+
valOp->replaceUsesWithIf(clonedOp, replace);
932943
} else {
933944
auto savedIP = firOpBuilder.getInsertionPoint();
934945
firOpBuilder.setInsertionPointAfter(valOp);
935946
auto copyVal =
936947
firOpBuilder.createTemporary(val.getLoc(), val.getType());
937948
firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
938949

939-
llvm::SmallVector<mlir::Value> bounds;
950+
lower::AddrAndBoundsInfo info = lower::getDataOperandBaseAddr(
951+
firOpBuilder, val, /*isOptional=*/false, val.getLoc());
952+
llvm::SmallVector<mlir::Value> bounds =
953+
Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
954+
mlir::omp::MapBoundsType>(
955+
firOpBuilder, info,
956+
hlfir::translateToExtendedValue(val.getLoc(), firOpBuilder,
957+
hlfir::Entity{val})
958+
.first,
959+
/*dataExvIsAssumedSize=*/false, val.getLoc());
960+
940961
std::stringstream name;
941962
firOpBuilder.setInsertionPoint(targetOp);
963+
964+
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
965+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
966+
mlir::omp::VariableCaptureKind captureKind =
967+
mlir::omp::VariableCaptureKind::ByRef;
968+
969+
mlir::Type eleType = copyVal.getType();
970+
if (auto refType =
971+
mlir::dyn_cast<fir::ReferenceType>(copyVal.getType()))
972+
eleType = refType.getElementType();
973+
974+
if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
975+
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
976+
} else if (!fir::isa_builtin_cptr_type(eleType)) {
977+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
978+
}
979+
942980
mlir::Value mapOp = createMapInfoOp(
943981
firOpBuilder, copyVal.getLoc(), copyVal,
944982
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
945983
/*members=*/llvm::SmallVector<mlir::Value>{},
946984
/*membersIndex=*/mlir::ArrayAttr{},
947985
static_cast<
948986
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
949-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
950-
mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
987+
mapFlag),
988+
captureKind, copyVal.getType());
951989

952990
// Get the index of the first non-map argument before modifying mapVars,
953991
// then append an element to mapVars and an associated entry block

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,19 @@ class MapInfoFinalizationPass
158158
mlir::Value baseAddrAddr = builder.create<fir::BoxOffsetOp>(
159159
loc, descriptor, fir::BoxFieldAttr::base_addr);
160160

161+
mlir::Type underlyingVarType =
162+
llvm::cast<mlir::omp::PointerLikeType>(
163+
fir::unwrapRefType(baseAddrAddr.getType()))
164+
.getElementType();
165+
if (auto seqType = llvm::dyn_cast<fir::SequenceType>(underlyingVarType))
166+
if (seqType.hasDynamicExtents())
167+
underlyingVarType = seqType.getEleTy();
168+
161169
// Member of the descriptor pointing at the allocated data
162170
return builder.create<mlir::omp::MapInfoOp>(
163171
loc, baseAddrAddr.getType(), descriptor,
164-
mlir::TypeAttr::get(llvm::cast<mlir::omp::PointerLikeType>(
165-
fir::unwrapRefType(baseAddrAddr.getType()))
166-
.getElementType()),
167-
baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{},
172+
mlir::TypeAttr::get(underlyingVarType), baseAddrAddr,
173+
/*members=*/mlir::SmallVector<mlir::Value>{},
168174
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
169175
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
170176
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(

flang/test/Lower/OpenMP/allocatable-array-bounds.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
!HOST: %[[BOX_3:.*]]:3 = fir.box_dims %[[LOAD_3]], %[[CONSTANT_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
2424
!HOST: %[[BOUNDS_1:.*]] = omp.map.bounds lower_bound(%[[LB_1]] : index) upper_bound(%[[UB_1]] : index) extent(%[[BOX_3]]#1 : index) stride(%[[BOX_2]]#2 : index) start_idx(%[[BOX_1]]#0 : index) {stride_in_bytes = true}
2525
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE_1]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
26-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_1]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
26+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_1]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
2727
!HOST: %[[MAP_INFO_1:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "sp_read(2:5)"}
2828

2929
!HOST: %[[LOAD_3:.*]] = fir.load %[[DECLARE_2]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -41,7 +41,7 @@
4141
!HOST: %[[BOX_5:.*]]:3 = fir.box_dims %[[LOAD_5]], %[[CONSTANT_5]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
4242
!HOST: %[[BOUNDS_2:.*]] = omp.map.bounds lower_bound(%[[LB_2]] : index) upper_bound(%[[UB_2]] : index) extent(%[[BOX_5]]#1 : index) stride(%[[BOX_4]]#2 : index) start_idx(%[[BOX_3]]#0 : index) {stride_in_bytes = true}
4343
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE_2]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
44-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_2]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
44+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_2]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
4545
!HOST: %[[MAP_INFO_2:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "sp_write(2:5)"}
4646

4747
subroutine read_write_section()
@@ -80,8 +80,9 @@ module assumed_allocatable_array_routines
8080
!HOST: %[[BOX_3:.*]]:3 = fir.box_dims %[[LOAD_3]], %[[CONSTANT_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
8181
!HOST: %[[BOUNDS:.*]] = omp.map.bounds lower_bound(%[[LB]] : index) upper_bound(%[[UB]] : index) extent(%[[BOX_3]]#1 : index) stride(%[[BOX_2]]#2 : index) start_idx(%[[BOX_1]]#0 : index) {stride_in_bytes = true}
8282
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
83-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
83+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
8484
!HOST: %[[MAP_INFO:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "arr_read_write(2:5)"}
85+
8586
subroutine assumed_shape_array(arr_read_write)
8687
integer, allocatable, intent(inout) :: arr_read_write(:)
8788

flang/test/Lower/OpenMP/array-bounds.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ module assumed_array_routines
5151
!HOST: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0_DECL]]#1, %[[C0_1]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
5252
!HOST: %[[BOUNDS:.*]] = omp.map.bounds lower_bound(%[[C3]] : index) upper_bound(%[[C4]] : index) extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) start_idx(%[[C0]] : index) {stride_in_bytes = true}
5353
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
54-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
54+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
5555
!HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
5656
!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
5757
subroutine assumed_shape_array(arr_read_write)

0 commit comments

Comments
 (0)