Skip to content

[flang] lower remaining cases of pointer assignments inside forall #130772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,19 @@ mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
std::optional<std::int64_t> getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
mlir::Value stride);

/// Compute the extent value given the lower bound \lb and upper bound \ub.
/// All inputs must have the same SSA integer type.
mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value lb, mlir::Value ub);
mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value lb, mlir::Value ub, mlir::Value zero,
mlir::Value one);

/// Generate max(\p value, 0) where \p value is a scalar integer.
mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value value);
mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value value, mlir::Value zero);

/// The type(C_PTR/C_FUNPTR) is defined as the derived type with only one
/// component of integer 64, and the component is the C address. Get the C
Expand Down
98 changes: 65 additions & 33 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4353,30 +4353,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
stmtCtx);
}

void genForallPointerAssignment(
mlir::Location loc, const Fortran::evaluate::Assignment &assign,
const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
std::optional<Fortran::evaluate::DynamicType> lhsType =
assign.lhs.GetType();
// Polymorphic pointer assignment is delegated to the runtime, and
// PointerAssociateLowerBounds needs the lower bounds as arguments, so they
// must be preserved.
if (lhsType && lhsType->IsPolymorphic())
TODO(loc, "polymorphic pointer assignment in FORALL");
// Nullification is special, there is no RHS that can be prepared,
// need to encode it in HLFIR.
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
assign.rhs))
TODO(loc, "NULL pointer assignment in FORALL");
// Lower bounds could be "applied" when preparing RHS, but in order
// to deal with the polymorphic case and to reuse existing pointer
// assignment helpers in HLFIR codegen, it is better to keep them
// separate.
if (!lbExprs.empty())
TODO(loc, "Pointer assignment with new lower bounds inside FORALL");
// Otherwise, this is a "dumb" pointer assignment that can be represented
// with hlfir.region_assign with descriptor address/value and later
// implemented with a store.
void genForallPointerAssignment(mlir::Location loc,
const Fortran::evaluate::Assignment &assign) {
// Lower pointer assignment inside forall with hlfir.region_assign with
// descriptor address/value and later implemented with a store.
// The RHS is fully prepared in lowering, so that all that is left
// in hlfir.region_assign code generation is the store.
auto regionAssignOp = builder->create<hlfir::RegionAssignOp>(loc);

// Lower LHS in its own region.
Expand All @@ -4400,22 +4382,73 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->setInsertionPointAfter(regionAssignOp);
}

mlir::Value lowerToIndexValue(mlir::Location loc,
const Fortran::evaluate::ExtentExpr &expr,
Fortran::lower::StatementContext &stmtCtx) {
mlir::Value val = fir::getBase(genExprValue(toEvExpr(expr), stmtCtx));
return builder->createConvert(loc, builder->getIndexType(), val);
}

mlir::Value
genForallPointerAssignmentRhs(mlir::Location loc, mlir::Value lhs,
const Fortran::evaluate::Assignment &assign,
Fortran::lower::StatementContext &rhsContext) {
if (Fortran::evaluate::IsProcedureDesignator(assign.rhs))
if (Fortran::evaluate::IsProcedureDesignator(assign.lhs)) {
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
assign.rhs))
return fir::factory::createNullBoxProc(
*builder, loc, fir::unwrapRefType(lhs.getType()));
return fir::getBase(Fortran::lower::convertExprToAddress(
loc, *this, assign.rhs, localSymbols, rhsContext));
}
// Data target.
auto lhsBoxType =
llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
// For NULL, create disassociated descriptor whose dynamic type is
// the static type of the LHS.
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
assign.rhs))
return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType,
std::nullopt);
hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
loc, *this, assign.rhs, localSymbols, rhsContext);
// Create pointer descriptor value from the RHS.
if (rhs.isMutableBox())
rhs = hlfir::Entity{builder->create<fir::LoadOp>(loc, rhs)};
auto lhsBoxType =
llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
return hlfir::genVariableBox(loc, *builder, rhs, lhsBoxType);
mlir::Value rhsBox = hlfir::genVariableBox(
loc, *builder, rhs, lhsBoxType.getBoxTypeWithNewShape(rhs.getRank()));
// Apply lower bounds or reshaping if any.
if (const auto *lbExprs =
std::get_if<Fortran::evaluate::Assignment::BoundsSpec>(&assign.u);
lbExprs && !lbExprs->empty()) {
// Override target lower bounds with the LHS bounds spec.
llvm::SmallVector<mlir::Value> lbounds;
for (const Fortran::evaluate::ExtentExpr &lbExpr : *lbExprs)
lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
mlir::Value shift = builder->genShift(loc, lbounds);
rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shift,
/*slice=*/mlir::Value{});
} else if (const auto *boundExprs =
std::get_if<Fortran::evaluate::Assignment::BoundsRemapping>(
&assign.u);
boundExprs && !boundExprs->empty()) {
// Reshape the target according to the LHS bounds remapping.
llvm::SmallVector<mlir::Value> lbounds;
llvm::SmallVector<mlir::Value> extents;
mlir::Type indexTy = builder->getIndexType();
mlir::Value zero = builder->createIntegerConstant(loc, indexTy, 0);
mlir::Value one = builder->createIntegerConstant(loc, indexTy, 1);
for (const auto &[lbExpr, ubExpr] : *boundExprs) {
lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
mlir::Value ub = lowerToIndexValue(loc, ubExpr, rhsContext);
extents.push_back(fir::factory::computeExtent(
*builder, loc, lbounds.back(), ub, zero, one));
}
mlir::Value shape = builder->genShape(loc, lbounds, extents);
rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shape,
/*slice=*/mlir::Value{});
}
return rhsBox;
}

// Create the 2 x newRank array with the bounds to be passed to the runtime as
Expand Down Expand Up @@ -4856,17 +4889,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
if (isInsideHlfirForallOrWhere())
genForallPointerAssignment(loc, assign, lbExprs);
genForallPointerAssignment(loc, assign);
else
genPointerAssignment(loc, assign, lbExprs);
},
[&](const Fortran::evaluate::Assignment::BoundsRemapping
&boundExprs) {
if (isInsideHlfirForallOrWhere())
TODO(
loc,
"pointer assignment with bounds remapping inside FORALL");
genPointerAssignment(loc, assign, boundExprs);
genForallPointerAssignment(loc, assign);
else
genPointerAssignment(loc, assign, boundExprs);
},
},
assign.u);
Expand Down
18 changes: 4 additions & 14 deletions flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,17 +1519,6 @@ static bool lowerToBoxValue(const Fortran::semantics::Symbol &sym,
return false;
}

/// Compute extent from lower and upper bound.
static mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value lb, mlir::Value ub) {
mlir::IndexType idxTy = builder.getIndexType();
// Let the folder deal with the common `ub - <const> + 1` case.
auto diff = builder.create<mlir::arith::SubIOp>(loc, idxTy, ub, lb);
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, idxTy, diff, one);
return fir::factory::genMaxWithZero(builder, loc, rawExtent);
}

/// Lower explicit lower bounds into \p result. Does nothing if this is not an
/// array, or if the lower bounds are deferred, or all implicit or one.
static void lowerExplicitLowerBounds(
Expand Down Expand Up @@ -1593,8 +1582,8 @@ lowerExplicitExtents(Fortran::lower::AbstractConverter &converter,
if (lowerBounds.empty())
result.emplace_back(fir::factory::genMaxWithZero(builder, loc, ub));
else
result.emplace_back(
computeExtent(builder, loc, lowerBounds[spec.index()], ub));
result.emplace_back(fir::factory::computeExtent(
builder, loc, lowerBounds[spec.index()], ub));
} else if (spec.value()->ubound().isStar()) {
result.emplace_back(getAssumedSizeExtent(loc, builder));
}
Expand Down Expand Up @@ -2214,7 +2203,8 @@ void Fortran::lower::mapSymbolAttributes(
if (auto high = spec->ubound().GetExplicit()) {
auto expr = Fortran::lower::SomeExpr{*high};
ub = builder.createConvert(loc, idxTy, genValue(expr));
extents.emplace_back(computeExtent(builder, loc, lb, ub));
extents.emplace_back(
fir::factory::computeExtent(builder, loc, lb, ub));
} else {
// An assumed size array. The extent is not computed.
assert(spec->ubound().isStar() && "expected assumed size");
Expand Down
31 changes: 28 additions & 3 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,9 +1609,8 @@ fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
}

mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value value) {
mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
mlir::Location loc, mlir::Value value,
mlir::Value zero) {
if (mlir::Operation *definingOp = value.getDefiningOp())
if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue()))
Expand All @@ -1622,6 +1621,32 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
zero);
}

mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value value) {
mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
return genMaxWithZero(builder, loc, value, zero);
}

mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value lb,
mlir::Value ub, mlir::Value zero,
mlir::Value one) {
mlir::Type type = lb.getType();
// Let the folder deal with the common `ub - <const> + 1` case.
auto diff = builder.create<mlir::arith::SubIOp>(loc, type, ub, lb);
auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, type, diff, one);
return fir::factory::genMaxWithZero(builder, loc, rawExtent, zero);
}
mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value lb,
mlir::Value ub) {
mlir::Type type = lb.getType();
mlir::Value one = builder.createIntegerConstant(loc, type, 1);
mlir::Value zero = builder.createIntegerConstant(loc, type, 0);
return computeExtent(builder, loc, lb, ub, zero, one);
}

static std::pair<mlir::Value, mlir::Type>
genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type cptrTy) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
! Test analysis of pointer assignment inside FORALL with lower bounds or bounds
! remapping.
! The analysis must detect if the evaluation of the LHS or RHS may be impacted
! by the pointer assignments, or if the forall can be lowered into a single
! loop without any temporary copy.

! RUN: bbc -hlfir -o /dev/null -pass-pipeline="builtin.module(lower-hlfir-ordered-assignments)" \
! RUN: --debug-only=flang-ordered-assignment -flang-dbg-order-assignment-schedule-only %s 2>&1 | FileCheck %s
! REQUIRES: asserts
module forall_pointers_bounds
type ptr_wrapper
integer, pointer :: p(:, :)
end type
contains

! Simple case that can be lowered into a single loop.
subroutine test_lb_no_conflict(a, iarray)
type(ptr_wrapper) :: a(:)
integer, target :: iarray(:, :)
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:,2*i:) => iarray
end subroutine

subroutine test_remapping_no_conflict(a, iarray)
type(ptr_wrapper) :: a(:)
integer, target :: iarray(6)
! Reshaping 6 to 2x3 with custom lower bounds.
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:2*(i-1)+2,2*i:2*i+2) => iarray
end subroutine
! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_remapping_no_conflict ------------
! CHECK-NEXT: run 1 evaluate: forall/region_assign1

! Bounds expression conflict. Note that even though they are syntactically on
! the LHS,they are saved with the RHS because they are applied when preparing the
! new descriptor value pointing to the RHS.
subroutine test_lb_conflict(a, iarray)
type(ptr_wrapper) :: a(:)
integer, target :: iarray(:, :)
integer :: n
n = ubound(a,1)
forall(i=lbound(a,1):ubound(a,1)) a(i)%p(a(n+1-i)%p(1,1):,a(n+1-i)%p(2,1):) => iarray
end subroutine
! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_lb_conflict ------------
! CHECK-NEXT: conflict: R/W
! CHECK-NEXT: run 1 save : forall/region_assign1/rhs
! CHECK-NEXT: run 2 evaluate: forall/region_assign1

end module

! End to end test provided for debugging purpose (not run by lit).
program end_to_end
use forall_pointers_bounds
integer, parameter :: n = 5
integer, target, save :: data(2, 2, n) = reshape([(i, i=1,size(data))], shape=shape(data))
integer, target, save :: data2(6) = reshape([(i, i=1,size(data2))], shape=shape(data2))
type(ptr_wrapper) :: pointers(n)
! Print pointer/target mapping baseline.
call reset_pointers(pointers)
if (.not.check_equal(pointers, [17,18,19,20,13,14,15,16,9,10,11,12,5,6,7,8,1,2,3,4])) stop 1

call reset_pointers(pointers)
call test_lb_no_conflict(pointers, data(:, :, 1))
if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 2
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 3

call reset_pointers(pointers)
call test_remapping_no_conflict(pointers, data2)
if (.not.check_equal(pointers, [([1,2,3,4,5,6],i=1,n)])) stop 4
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 5
if (.not.all([(ubound(pointers(i)%p), i=1,n)].eq.[([2*(i-1)+2, 2*i+2], i=1,n)])) stop 6

call reset_pointers(pointers)
call test_lb_conflict(pointers, data(:, :, 1))
if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 7
if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[([data(1,1,i), data(2,1,i)], i=1,n)])) stop 8

print *, "PASS"
contains
subroutine reset_pointers(a)
type(ptr_wrapper) :: a(:)
do i=1,n
a(i)%p => data(:, :, n+1-i)
end do
end subroutine
logical function check_equal(a, expected)
type(ptr_wrapper) :: a(:)
integer :: expected(:)
check_equal = all([(a(i)%p, i=1,n)].eq.expected)
if (.not.check_equal) then
print *, "expected:", expected
print *, "got:", [(a(i)%p, i=1,n)]
end if
end function
end
Loading