Skip to content

Commit 3ff3b29

Browse files
authored
[flang] lower remaining cases of pointer assignments inside forall (#130772)
Implement handling of `NULL()` RHS, polymorphic pointers, as well as lower bounds or bounds remapping in pointer assignment inside FORALL. These cases eventually do not require updating hlfir.region_assign, lowering can simply prepare the new descriptor for the LHS inside the RHS region. Looking more closely at the polymorphic cases, there is not need to call the runtime, fir.rebox and fir.embox do handle the dynamic type setting correctly. After this patch, the last remaining TODO is the allocatable assignment inside FORALL, which like some cases here, is more likely an accidental feature given FORALL was deprecated in F2003 at the same time than allocatable components where added.
1 parent 7bae613 commit 3ff3b29

10 files changed

+392
-65
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,19 @@ mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
774774
std::optional<std::int64_t> getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
775775
mlir::Value stride);
776776

777+
/// Compute the extent value given the lower bound \lb and upper bound \ub.
778+
/// All inputs must have the same SSA integer type.
779+
mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
780+
mlir::Value lb, mlir::Value ub);
781+
mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
782+
mlir::Value lb, mlir::Value ub, mlir::Value zero,
783+
mlir::Value one);
784+
777785
/// Generate max(\p value, 0) where \p value is a scalar integer.
778786
mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
779787
mlir::Value value);
788+
mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
789+
mlir::Value value, mlir::Value zero);
780790

781791
/// The type(C_PTR/C_FUNPTR) is defined as the derived type with only one
782792
/// component of integer 64, and the component is the C address. Get the C

flang/lib/Lower/Bridge.cpp

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4353,30 +4353,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
43534353
stmtCtx);
43544354
}
43554355

4356-
void genForallPointerAssignment(
4357-
mlir::Location loc, const Fortran::evaluate::Assignment &assign,
4358-
const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
4359-
std::optional<Fortran::evaluate::DynamicType> lhsType =
4360-
assign.lhs.GetType();
4361-
// Polymorphic pointer assignment is delegated to the runtime, and
4362-
// PointerAssociateLowerBounds needs the lower bounds as arguments, so they
4363-
// must be preserved.
4364-
if (lhsType && lhsType->IsPolymorphic())
4365-
TODO(loc, "polymorphic pointer assignment in FORALL");
4366-
// Nullification is special, there is no RHS that can be prepared,
4367-
// need to encode it in HLFIR.
4368-
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
4369-
assign.rhs))
4370-
TODO(loc, "NULL pointer assignment in FORALL");
4371-
// Lower bounds could be "applied" when preparing RHS, but in order
4372-
// to deal with the polymorphic case and to reuse existing pointer
4373-
// assignment helpers in HLFIR codegen, it is better to keep them
4374-
// separate.
4375-
if (!lbExprs.empty())
4376-
TODO(loc, "Pointer assignment with new lower bounds inside FORALL");
4377-
// Otherwise, this is a "dumb" pointer assignment that can be represented
4378-
// with hlfir.region_assign with descriptor address/value and later
4379-
// implemented with a store.
4356+
void genForallPointerAssignment(mlir::Location loc,
4357+
const Fortran::evaluate::Assignment &assign) {
4358+
// Lower pointer assignment inside forall with hlfir.region_assign with
4359+
// descriptor address/value and later implemented with a store.
4360+
// The RHS is fully prepared in lowering, so that all that is left
4361+
// in hlfir.region_assign code generation is the store.
43804362
auto regionAssignOp = builder->create<hlfir::RegionAssignOp>(loc);
43814363

43824364
// Lower LHS in its own region.
@@ -4400,22 +4382,73 @@ class FirConverter : public Fortran::lower::AbstractConverter {
44004382
builder->setInsertionPointAfter(regionAssignOp);
44014383
}
44024384

4385+
mlir::Value lowerToIndexValue(mlir::Location loc,
4386+
const Fortran::evaluate::ExtentExpr &expr,
4387+
Fortran::lower::StatementContext &stmtCtx) {
4388+
mlir::Value val = fir::getBase(genExprValue(toEvExpr(expr), stmtCtx));
4389+
return builder->createConvert(loc, builder->getIndexType(), val);
4390+
}
4391+
44034392
mlir::Value
44044393
genForallPointerAssignmentRhs(mlir::Location loc, mlir::Value lhs,
44054394
const Fortran::evaluate::Assignment &assign,
44064395
Fortran::lower::StatementContext &rhsContext) {
4407-
if (Fortran::evaluate::IsProcedureDesignator(assign.rhs))
4396+
if (Fortran::evaluate::IsProcedureDesignator(assign.lhs)) {
4397+
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
4398+
assign.rhs))
4399+
return fir::factory::createNullBoxProc(
4400+
*builder, loc, fir::unwrapRefType(lhs.getType()));
44084401
return fir::getBase(Fortran::lower::convertExprToAddress(
44094402
loc, *this, assign.rhs, localSymbols, rhsContext));
4403+
}
44104404
// Data target.
4405+
auto lhsBoxType =
4406+
llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
4407+
// For NULL, create disassociated descriptor whose dynamic type is
4408+
// the static type of the LHS.
4409+
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
4410+
assign.rhs))
4411+
return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType,
4412+
std::nullopt);
44114413
hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
44124414
loc, *this, assign.rhs, localSymbols, rhsContext);
44134415
// Create pointer descriptor value from the RHS.
44144416
if (rhs.isMutableBox())
44154417
rhs = hlfir::Entity{builder->create<fir::LoadOp>(loc, rhs)};
4416-
auto lhsBoxType =
4417-
llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
4418-
return hlfir::genVariableBox(loc, *builder, rhs, lhsBoxType);
4418+
mlir::Value rhsBox = hlfir::genVariableBox(
4419+
loc, *builder, rhs, lhsBoxType.getBoxTypeWithNewShape(rhs.getRank()));
4420+
// Apply lower bounds or reshaping if any.
4421+
if (const auto *lbExprs =
4422+
std::get_if<Fortran::evaluate::Assignment::BoundsSpec>(&assign.u);
4423+
lbExprs && !lbExprs->empty()) {
4424+
// Override target lower bounds with the LHS bounds spec.
4425+
llvm::SmallVector<mlir::Value> lbounds;
4426+
for (const Fortran::evaluate::ExtentExpr &lbExpr : *lbExprs)
4427+
lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
4428+
mlir::Value shift = builder->genShift(loc, lbounds);
4429+
rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shift,
4430+
/*slice=*/mlir::Value{});
4431+
} else if (const auto *boundExprs =
4432+
std::get_if<Fortran::evaluate::Assignment::BoundsRemapping>(
4433+
&assign.u);
4434+
boundExprs && !boundExprs->empty()) {
4435+
// Reshape the target according to the LHS bounds remapping.
4436+
llvm::SmallVector<mlir::Value> lbounds;
4437+
llvm::SmallVector<mlir::Value> extents;
4438+
mlir::Type indexTy = builder->getIndexType();
4439+
mlir::Value zero = builder->createIntegerConstant(loc, indexTy, 0);
4440+
mlir::Value one = builder->createIntegerConstant(loc, indexTy, 1);
4441+
for (const auto &[lbExpr, ubExpr] : *boundExprs) {
4442+
lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
4443+
mlir::Value ub = lowerToIndexValue(loc, ubExpr, rhsContext);
4444+
extents.push_back(fir::factory::computeExtent(
4445+
*builder, loc, lbounds.back(), ub, zero, one));
4446+
}
4447+
mlir::Value shape = builder->genShape(loc, lbounds, extents);
4448+
rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shape,
4449+
/*slice=*/mlir::Value{});
4450+
}
4451+
return rhsBox;
44194452
}
44204453

44214454
// Create the 2 x newRank array with the bounds to be passed to the runtime as
@@ -4856,17 +4889,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48564889
},
48574890
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
48584891
if (isInsideHlfirForallOrWhere())
4859-
genForallPointerAssignment(loc, assign, lbExprs);
4892+
genForallPointerAssignment(loc, assign);
48604893
else
48614894
genPointerAssignment(loc, assign, lbExprs);
48624895
},
48634896
[&](const Fortran::evaluate::Assignment::BoundsRemapping
48644897
&boundExprs) {
48654898
if (isInsideHlfirForallOrWhere())
4866-
TODO(
4867-
loc,
4868-
"pointer assignment with bounds remapping inside FORALL");
4869-
genPointerAssignment(loc, assign, boundExprs);
4899+
genForallPointerAssignment(loc, assign);
4900+
else
4901+
genPointerAssignment(loc, assign, boundExprs);
48704902
},
48714903
},
48724904
assign.u);

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,17 +1577,6 @@ static bool lowerToBoxValue(const Fortran::semantics::Symbol &sym,
15771577
return false;
15781578
}
15791579

1580-
/// Compute extent from lower and upper bound.
1581-
static mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
1582-
mlir::Value lb, mlir::Value ub) {
1583-
mlir::IndexType idxTy = builder.getIndexType();
1584-
// Let the folder deal with the common `ub - <const> + 1` case.
1585-
auto diff = builder.create<mlir::arith::SubIOp>(loc, idxTy, ub, lb);
1586-
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
1587-
auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, idxTy, diff, one);
1588-
return fir::factory::genMaxWithZero(builder, loc, rawExtent);
1589-
}
1590-
15911580
/// Lower explicit lower bounds into \p result. Does nothing if this is not an
15921581
/// array, or if the lower bounds are deferred, or all implicit or one.
15931582
static void lowerExplicitLowerBounds(
@@ -1651,8 +1640,8 @@ lowerExplicitExtents(Fortran::lower::AbstractConverter &converter,
16511640
if (lowerBounds.empty())
16521641
result.emplace_back(fir::factory::genMaxWithZero(builder, loc, ub));
16531642
else
1654-
result.emplace_back(
1655-
computeExtent(builder, loc, lowerBounds[spec.index()], ub));
1643+
result.emplace_back(fir::factory::computeExtent(
1644+
builder, loc, lowerBounds[spec.index()], ub));
16561645
} else if (spec.value()->ubound().isStar()) {
16571646
result.emplace_back(getAssumedSizeExtent(loc, builder));
16581647
}
@@ -2272,7 +2261,8 @@ void Fortran::lower::mapSymbolAttributes(
22722261
if (auto high = spec->ubound().GetExplicit()) {
22732262
auto expr = Fortran::lower::SomeExpr{*high};
22742263
ub = builder.createConvert(loc, idxTy, genValue(expr));
2275-
extents.emplace_back(computeExtent(builder, loc, lb, ub));
2264+
extents.emplace_back(
2265+
fir::factory::computeExtent(builder, loc, lb, ub));
22762266
} else {
22772267
// An assumed size array. The extent is not computed.
22782268
assert(spec->ubound().isStar() && "expected assumed size");

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,9 +1609,8 @@ fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
16091609
}
16101610

16111611
mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
1612-
mlir::Location loc,
1613-
mlir::Value value) {
1614-
mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
1612+
mlir::Location loc, mlir::Value value,
1613+
mlir::Value zero) {
16151614
if (mlir::Operation *definingOp = value.getDefiningOp())
16161615
if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
16171616
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue()))
@@ -1622,6 +1621,32 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
16221621
zero);
16231622
}
16241623

1624+
mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
1625+
mlir::Location loc,
1626+
mlir::Value value) {
1627+
mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
1628+
return genMaxWithZero(builder, loc, value, zero);
1629+
}
1630+
1631+
mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
1632+
mlir::Location loc, mlir::Value lb,
1633+
mlir::Value ub, mlir::Value zero,
1634+
mlir::Value one) {
1635+
mlir::Type type = lb.getType();
1636+
// Let the folder deal with the common `ub - <const> + 1` case.
1637+
auto diff = builder.create<mlir::arith::SubIOp>(loc, type, ub, lb);
1638+
auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, type, diff, one);
1639+
return fir::factory::genMaxWithZero(builder, loc, rawExtent, zero);
1640+
}
1641+
mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
1642+
mlir::Location loc, mlir::Value lb,
1643+
mlir::Value ub) {
1644+
mlir::Type type = lb.getType();
1645+
mlir::Value one = builder.createIntegerConstant(loc, type, 1);
1646+
mlir::Value zero = builder.createIntegerConstant(loc, type, 0);
1647+
return computeExtent(builder, loc, lb, ub, zero, one);
1648+
}
1649+
16251650
static std::pair<mlir::Value, mlir::Type>
16261651
genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
16271652
mlir::Type cptrTy) {
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
! Test analysis of pointer assignment inside FORALL with lower bounds or bounds
2+
! remapping.
3+
! The analysis must detect if the evaluation of the LHS or RHS may be impacted
4+
! by the pointer assignments, or if the forall can be lowered into a single
5+
! loop without any temporary copy.
6+
7+
! RUN: bbc -hlfir -o /dev/null -pass-pipeline="builtin.module(lower-hlfir-ordered-assignments)" \
8+
! RUN: --debug-only=flang-ordered-assignment -flang-dbg-order-assignment-schedule-only %s 2>&1 | FileCheck %s
9+
! REQUIRES: asserts
10+
module forall_pointers_bounds
11+
type ptr_wrapper
12+
integer, pointer :: p(:, :)
13+
end type
14+
contains
15+
16+
! Simple case that can be lowered into a single loop.
17+
subroutine test_lb_no_conflict(a, iarray)
18+
type(ptr_wrapper) :: a(:)
19+
integer, target :: iarray(:, :)
20+
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:,2*i:) => iarray
21+
end subroutine
22+
23+
subroutine test_remapping_no_conflict(a, iarray)
24+
type(ptr_wrapper) :: a(:)
25+
integer, target :: iarray(6)
26+
! Reshaping 6 to 2x3 with custom lower bounds.
27+
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:2*(i-1)+2,2*i:2*i+2) => iarray
28+
end subroutine
29+
! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_remapping_no_conflict ------------
30+
! CHECK-NEXT: run 1 evaluate: forall/region_assign1
31+
32+
! Bounds expression conflict. Note that even though they are syntactically on
33+
! the LHS,they are saved with the RHS because they are applied when preparing the
34+
! new descriptor value pointing to the RHS.
35+
subroutine test_lb_conflict(a, iarray)
36+
type(ptr_wrapper) :: a(:)
37+
integer, target :: iarray(:, :)
38+
integer :: n
39+
n = ubound(a,1)
40+
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(a(n+1-i)%p(1,1):,a(n+1-i)%p(2,1):) => iarray
41+
end subroutine
42+
! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_lb_conflict ------------
43+
! CHECK-NEXT: conflict: R/W
44+
! CHECK-NEXT: run 1 save : forall/region_assign1/rhs
45+
! CHECK-NEXT: run 2 evaluate: forall/region_assign1
46+
47+
end module
48+
49+
! End to end test provided for debugging purpose (not run by lit).
50+
program end_to_end
51+
use forall_pointers_bounds
52+
integer, parameter :: n = 5
53+
integer, target, save :: data(2, 2, n) = reshape([(i, i=1,size(data))], shape=shape(data))
54+
integer, target, save :: data2(6) = reshape([(i, i=1,size(data2))], shape=shape(data2))
55+
type(ptr_wrapper) :: pointers(n)
56+
! Print pointer/target mapping baseline.
57+
call reset_pointers(pointers)
58+
if (.not.check_equal(pointers, [17,18,19,20,13,14,15,16,9,10,11,12,5,6,7,8,1,2,3,4])) stop 1
59+
60+
call reset_pointers(pointers)
61+
call test_lb_no_conflict(pointers, data(:, :, 1))
62+
if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 2
63+
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 3
64+
65+
call reset_pointers(pointers)
66+
call test_remapping_no_conflict(pointers, data2)
67+
if (.not.check_equal(pointers, [([1,2,3,4,5,6],i=1,n)])) stop 4
68+
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 5
69+
if (.not.all([(ubound(pointers(i)%p), i=1,n)].eq.[([2*(i-1)+2, 2*i+2], i=1,n)])) stop 6
70+
71+
call reset_pointers(pointers)
72+
call test_lb_conflict(pointers, data(:, :, 1))
73+
if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 7
74+
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[([data(1,1,i), data(2,1,i)], i=1,n)])) stop 8
75+
76+
print *, "PASS"
77+
contains
78+
subroutine reset_pointers(a)
79+
type(ptr_wrapper) :: a(:)
80+
do i=1,n
81+
a(i)%p => data(:, :, n+1-i)
82+
end do
83+
end subroutine
84+
logical function check_equal(a, expected)
85+
type(ptr_wrapper) :: a(:)
86+
integer :: expected(:)
87+
check_equal = all([(a(i)%p, i=1,n)].eq.expected)
88+
if (.not.check_equal) then
89+
print *, "expected:", expected
90+
print *, "got:", [(a(i)%p, i=1,n)]
91+
end if
92+
end function
93+
end

0 commit comments

Comments
 (0)