Skip to content

Commit 7302e1b

Browse files
authored
[flang] implement simple pointer assignments inside FORALL (#129522)
The semantic of pointer assignments inside FORALL requires evaluating the targets (RHS) and pointer variables (LHS) of all iterations before evaluating the assignments. In practice, if the compiler can prove that the RHS and LHS evaluations are not impacted by the assignments, the evaluation of the FORALL assignment statement can be done in a single loop. However, if the compiler cannot prove this, it needs to "save" the addresses of the targets and/or the pointer descriptors of each iterations before doing the assignments. This patch implements the most common cases where there is no lower bound spec, no bounds remapping, the LHS is not polymorphic, and the RHS is not NULL. The HLFIR operation used to represent assignments inside FORALL can be used for pointer assignments to (the only difference being that the LHS is a descriptor address). The analysis for intrinsic assignment can be reused, with the distinction that the RHS data is not read during the assignment. The logic that is used to save LHS in intrinsic assignments inside FORALL is extracted to be used for the RHS of pointer assignments when needed (saving a descriptor value). Pointer assignment LHS are just descriptor addresses and are saved as int_ptr values.
1 parent 5916903 commit 7302e1b

File tree

14 files changed

+605
-32
lines changed

14 files changed

+605
-32
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,11 @@ mlir::Value genVariableBoxChar(mlir::Location loc, fir::FirOpBuilder &builder,
249249
hlfir::Entity var);
250250

251251
/// Get or create a fir.box or fir.class from a variable.
252+
/// A fir.box with different attributes that \p var can be created
253+
/// using \p forceBoxType.
252254
hlfir::Entity genVariableBox(mlir::Location loc, fir::FirOpBuilder &builder,
253-
hlfir::Entity var);
255+
hlfir::Entity var,
256+
fir::BaseBoxType forceBoxType = {});
254257

255258
/// If the entity is a variable, load its value (dereference pointers and
256259
/// allocatables if needed). Do nothing if the entity is already a value, and

flang/include/flang/Optimizer/Builder/TemporaryStorage.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ class AnyValueStack {
179179
/// type. Fetching variable N will return a variable with the same address,
180180
/// dynamic type, bounds, and type parameters as the Nth variable that was
181181
/// pushed. It is implemented using runtime.
182+
/// Note that this is not meant to save POINTER or ALLOCATABLE descriptor
183+
/// addresses, use AnyDescriptorAddressStack instead.
182184
class AnyVariableStack {
183185
public:
184186
AnyVariableStack(mlir::Location loc, fir::FirOpBuilder &builder,
@@ -203,6 +205,21 @@ class AnyVariableStack {
203205
mlir::Value retValueBox;
204206
};
205207

208+
/// Data structure to stack descriptor addresses. It stores the descriptor
209+
/// addresses as int_ptr values under the hood.
210+
class AnyDescriptorAddressStack : public AnyValueStack {
211+
public:
212+
AnyDescriptorAddressStack(mlir::Location loc, fir::FirOpBuilder &builder,
213+
mlir::Type descriptorAddressType);
214+
215+
void pushValue(mlir::Location loc, fir::FirOpBuilder &builder,
216+
mlir::Value value);
217+
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder);
218+
219+
private:
220+
mlir::Type descriptorAddressType;
221+
};
222+
206223
class TemporaryStorage;
207224

208225
/// Data structure to stack vector subscripted entity shape and
@@ -264,7 +281,8 @@ class TemporaryStorage {
264281

265282
private:
266283
std::variant<HomogeneousScalarStack, SimpleCopy, SSARegister, AnyValueStack,
267-
AnyVariableStack, AnyVectorSubscriptStack>
284+
AnyVariableStack, AnyVectorSubscriptStack,
285+
AnyDescriptorAddressStack>
268286
impl;
269287
};
270288
} // namespace fir::factory

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ class BaseBoxType : public mlir::Type {
5353
/// Is this the box for an assumed rank?
5454
bool isAssumedRank() const;
5555

56+
/// Is this a box for a pointer?
57+
bool isPointer() const;
58+
5659
/// Return the same type, except for the shape, that is taken the shape
5760
/// of shapeMold.
5861
BaseBoxType getBoxTypeWithNewShape(mlir::Type shapeMold) const;

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,7 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre
13771377
regions.push_back(&getUserDefinedAssignment());
13781378
}
13791379
mlir::Region* getSubTreeRegion() { return nullptr; }
1380-
1380+
bool isPointerAssignment();
13811381
}];
13821382

13831383
let hasCustomAssemblyFormat = 1;

flang/lib/Lower/Bridge.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4355,6 +4355,62 @@ class FirConverter : public Fortran::lower::AbstractConverter {
43554355
stmtCtx);
43564356
}
43574357

4358+
void genForallPointerAssignment(
4359+
mlir::Location loc, const Fortran::evaluate::Assignment &assign,
4360+
const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
4361+
if (Fortran::evaluate::IsProcedureDesignator(assign.rhs))
4362+
TODO(loc, "procedure pointer assignment inside FORALL");
4363+
std::optional<Fortran::evaluate::DynamicType> lhsType =
4364+
assign.lhs.GetType();
4365+
// Polymorphic pointer assignment is delegated to the runtime, and
4366+
// PointerAssociateLowerBounds needs the lower bounds as arguments, so they
4367+
// must be preserved.
4368+
if (lhsType && lhsType->IsPolymorphic())
4369+
TODO(loc, "polymorphic pointer assignment in FORALL");
4370+
// Nullification is special, there is no RHS that can be prepared,
4371+
// need to encode it in HLFIR.
4372+
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
4373+
assign.rhs))
4374+
TODO(loc, "NULL pointer assignment in FORALL");
4375+
// Lower bounds could be "applied" when preparing RHS, but in order
4376+
// to deal with the polymorphic case and to reuse existing pointer
4377+
// assignment helpers in HLFIR codegen, it is better to keep them
4378+
// separate.
4379+
if (!lbExprs.empty())
4380+
TODO(loc, "Pointer assignment with new lower bounds inside FORALL");
4381+
// Otherwise, this is a "dumb" pointer assignment that can be represented
4382+
// with hlfir.region_assign with descriptor address/value and later
4383+
// implemented with a store.
4384+
auto regionAssignOp = builder->create<hlfir::RegionAssignOp>(loc);
4385+
4386+
// Lower LHS in its own region.
4387+
builder->createBlock(&regionAssignOp.getLhsRegion());
4388+
Fortran::lower::StatementContext lhsContext;
4389+
hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR(
4390+
loc, *this, assign.lhs, localSymbols, lhsContext);
4391+
4392+
auto lhsYieldOp = builder->create<hlfir::YieldOp>(loc, lhs);
4393+
Fortran::lower::genCleanUpInRegionIfAny(
4394+
loc, *builder, lhsYieldOp.getCleanup(), lhsContext);
4395+
4396+
// Lower RHS in its own region.
4397+
builder->createBlock(&regionAssignOp.getRhsRegion());
4398+
Fortran::lower::StatementContext rhsContext;
4399+
hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
4400+
loc, *this, assign.rhs, localSymbols, rhsContext);
4401+
// Create pointer descriptor value from the RHS.
4402+
if (rhs.isMutableBox())
4403+
rhs = hlfir::Entity{builder->create<fir::LoadOp>(loc, rhs)};
4404+
auto lhsBoxType =
4405+
llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
4406+
mlir::Value newBox = hlfir::genVariableBox(loc, *builder, rhs, lhsBoxType);
4407+
auto rhsYieldOp = builder->create<hlfir::YieldOp>(loc, newBox);
4408+
Fortran::lower::genCleanUpInRegionIfAny(
4409+
loc, *builder, rhsYieldOp.getCleanup(), rhsContext);
4410+
4411+
builder->setInsertionPointAfter(regionAssignOp);
4412+
}
4413+
43584414
// Create the 2 x newRank array with the bounds to be passed to the runtime as
43594415
// a descriptor.
43604416
mlir::Value createBoundArray(llvm::ArrayRef<mlir::Value> lbounds,
@@ -4793,13 +4849,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
47934849
},
47944850
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
47954851
if (isInsideHlfirForallOrWhere())
4796-
TODO(loc, "pointer assignment inside FORALL");
4797-
genPointerAssignment(loc, assign, lbExprs);
4852+
genForallPointerAssignment(loc, assign, lbExprs);
4853+
else
4854+
genPointerAssignment(loc, assign, lbExprs);
47984855
},
47994856
[&](const Fortran::evaluate::Assignment::BoundsRemapping
48004857
&boundExprs) {
48014858
if (isInsideHlfirForallOrWhere())
4802-
TODO(loc, "pointer assignment inside FORALL");
4859+
TODO(
4860+
loc,
4861+
"pointer assignment with bounds remapping inside FORALL");
48034862
genPointerAssignment(loc, assign, boundExprs);
48044863
},
48054864
},

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,26 +349,54 @@ mlir::Value hlfir::genVariableBoxChar(mlir::Location loc,
349349
lengths[0]);
350350
}
351351

352+
static hlfir::Entity changeBoxAttributes(mlir::Location loc,
353+
fir::FirOpBuilder &builder,
354+
hlfir::Entity var,
355+
fir::BaseBoxType forceBoxType) {
356+
assert(llvm::isa<fir::BaseBoxType>(var.getType()) && "expect box type");
357+
// Propagate lower bounds.
358+
mlir::Value shift;
359+
llvm::SmallVector<mlir::Value> lbounds =
360+
getNonDefaultLowerBounds(loc, builder, var);
361+
if (!lbounds.empty())
362+
shift = builder.genShift(loc, lbounds);
363+
auto rebox = builder.create<fir::ReboxOp>(loc, forceBoxType, var, shift,
364+
/*slice=*/nullptr);
365+
return hlfir::Entity{rebox};
366+
}
367+
352368
hlfir::Entity hlfir::genVariableBox(mlir::Location loc,
353369
fir::FirOpBuilder &builder,
354-
hlfir::Entity var) {
370+
hlfir::Entity var,
371+
fir::BaseBoxType forceBoxType) {
355372
assert(var.isVariable() && "must be a variable");
356373
var = hlfir::derefPointersAndAllocatables(loc, builder, var);
357-
if (mlir::isa<fir::BaseBoxType>(var.getType()))
358-
return var;
374+
if (mlir::isa<fir::BaseBoxType>(var.getType())) {
375+
if (!forceBoxType || forceBoxType == var.getType())
376+
return var;
377+
return changeBoxAttributes(loc, builder, var, forceBoxType);
378+
}
359379
// Note: if the var is not a fir.box/fir.class at that point, it has default
360380
// lower bounds and is not polymorphic.
361381
mlir::Value shape =
362382
var.isArray() ? hlfir::genShape(loc, builder, var) : mlir::Value{};
363383
llvm::SmallVector<mlir::Value> typeParams;
364-
auto maybeCharType =
365-
mlir::dyn_cast<fir::CharacterType>(var.getFortranElementType());
384+
mlir::Type elementType =
385+
forceBoxType ? fir::getFortranElementType(forceBoxType.getEleTy())
386+
: var.getFortranElementType();
387+
auto maybeCharType = mlir::dyn_cast<fir::CharacterType>(elementType);
366388
if (!maybeCharType || maybeCharType.hasDynamicLen())
367389
hlfir::genLengthParameters(loc, builder, var, typeParams);
368390
mlir::Value addr = var.getBase();
369391
if (mlir::isa<fir::BoxCharType>(var.getType()))
370392
addr = genVariableRawAddress(loc, builder, var);
371393
mlir::Type boxType = fir::BoxType::get(var.getElementOrSequenceType());
394+
if (forceBoxType) {
395+
boxType = forceBoxType;
396+
mlir::Type baseType =
397+
fir::ReferenceType::get(fir::unwrapRefType(forceBoxType.getEleTy()));
398+
addr = builder.createConvert(loc, baseType, addr);
399+
}
372400
auto embox =
373401
builder.create<fir::EmboxOp>(loc, boxType, addr, shape,
374402
/*slice=*/mlir::Value{}, typeParams);

flang/lib/Optimizer/Builder/TemporaryStorage.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,27 @@ void fir::factory::AnyVectorSubscriptStack::destroy(
355355
static_cast<AnyVariableStack *>(this)->destroy(loc, builder);
356356
shapeTemp->destroy(loc, builder);
357357
}
358+
359+
//===----------------------------------------------------------------------===//
360+
// fir::factory::AnyDescriptorAddressStack implementation.
361+
//===----------------------------------------------------------------------===//
362+
363+
fir::factory::AnyDescriptorAddressStack::AnyDescriptorAddressStack(
364+
mlir::Location loc, fir::FirOpBuilder &builder,
365+
mlir::Type descriptorAddressType)
366+
: AnyValueStack(loc, builder, builder.getIntPtrType()),
367+
descriptorAddressType{descriptorAddressType} {}
368+
369+
void fir::factory::AnyDescriptorAddressStack::pushValue(
370+
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value variable) {
371+
mlir::Value cast =
372+
builder.createConvert(loc, builder.getIntPtrType(), variable);
373+
static_cast<AnyValueStack *>(this)->pushValue(loc, builder, cast);
374+
}
375+
376+
mlir::Value
377+
fir::factory::AnyDescriptorAddressStack::fetch(mlir::Location loc,
378+
fir::FirOpBuilder &builder) {
379+
mlir::Value addr = static_cast<AnyValueStack *>(this)->fetch(loc, builder);
380+
return builder.createConvert(loc, descriptorAddressType, addr);
381+
}

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,10 @@ bool fir::BaseBoxType::isAssumedRank() const {
13591359
return false;
13601360
}
13611361

1362+
bool fir::BaseBoxType::isPointer() const {
1363+
return llvm::isa<fir::PointerType>(getEleTy());
1364+
}
1365+
13621366
//===----------------------------------------------------------------------===//
13631367
// FIROpsDialect
13641368
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,20 @@ llvm::LogicalResult hlfir::RegionAssignOp::verify() {
18911891
return mlir::success();
18921892
}
18931893

1894+
bool hlfir::RegionAssignOp::isPointerAssignment() {
1895+
if (!getUserDefinedAssignment().empty())
1896+
return false;
1897+
hlfir::YieldOp yieldOp =
1898+
mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(getLhsRegion()));
1899+
if (!yieldOp)
1900+
return false;
1901+
mlir::Type lhsType = yieldOp.getEntity().getType();
1902+
if (!hlfir::isBoxAddressType(lhsType))
1903+
return false;
1904+
auto baseBoxType = llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhsType));
1905+
return baseBoxType.isPointer();
1906+
}
1907+
18941908
//===----------------------------------------------------------------------===//
18951909
// YieldOp
18961910
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)