Skip to content

[flang] Do not hoist all scalar sub-expressions from WHERE constructs #91395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flang/include/flang/Lower/StatementContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
#include <functional>
#include <optional>

namespace mlir {
class Location;
class Region;
} // namespace mlir

namespace fir {
class FirOpBuilder;
}

namespace Fortran::lower {

/// When lowering a statement, temporaries for intermediate results may be
Expand Down Expand Up @@ -105,6 +114,11 @@ class StatementContext {
llvm::SmallVector<std::optional<CleanupFunction>> cufs;
};

/// If \p context contains any cleanups, ensure \p region has a block, and
/// generate the cleanup inside that block.
void genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Region &region, StatementContext &context);

} // namespace Fortran::lower

#endif // FORTRAN_LOWER_STATEMENTCONTEXT_H
24 changes: 23 additions & 1 deletion flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,8 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre
}

def hlfir_YieldOp : hlfir_Op<"yield", [Terminator, ParentOneOf<["RegionAssignOp",
"ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp"]>,
"ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp",
"ExactlyOnceOp"]>,
SingleBlockImplicitTerminator<"fir::FirEndOp">, RecursivelySpeculatable,
RecursiveMemoryEffects]> {

Expand Down Expand Up @@ -1594,6 +1595,27 @@ def hlfir_ForallMaskOp : hlfir_AssignmentMaskOp<"forall_mask"> {
let hasVerifier = 1;
}

def hlfir_ExactlyOnceOp : hlfir_Op<"exactly_once", [RecursiveMemoryEffects]> {
let summary = "Execute exactly once its region in a WhereOp";
let description = [{
Inside a Where assignment, Fortran requires a non elemental call and its
arguments to be executed exactly once, regardless of the mask values.
This operation allows holding these evaluations that cannot be hoisted
until potential parent Forall loops have been created.
It also allows inlining the calls without losing the information that
these calls must be hoisted.
}];

let regions = (region SizedRegion<1>:$body);

let results = (outs AnyFortranEntity:$result);

let assemblyFormat = [{
attr-dict `:` type($result)
$body
}];
}

def hlfir_WhereOp : hlfir_AssignmentMaskOp<"where"> {
let summary = "Represent a Fortran where construct or statement";
let description = [{
Expand Down
45 changes: 23 additions & 22 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3677,22 +3677,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return hlfir::Entity{valueAndPair.first};
}

static void
genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Region &region,
Fortran::lower::StatementContext &context) {
if (!context.hasCode())
return;
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
if (region.empty())
builder.createBlock(&region);
else
builder.setInsertionPointToEnd(&region.front());
context.finalizeAndPop();
hlfir::YieldOp::ensureTerminator(region, builder, loc);
builder.restoreInsertionPoint(insertPt);
}

bool firstDummyIsPointerOrAllocatable(
const Fortran::evaluate::ProcedureRef &userDefinedAssignment) {
using DummyAttr = Fortran::evaluate::characteristics::DummyDataObject::Attr;
Expand Down Expand Up @@ -3918,23 +3902,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::StatementContext rhsContext;
hlfir::Entity rhs = evaluateRhs(rhsContext);
auto rhsYieldOp = builder.create<hlfir::YieldOp>(loc, rhs);
genCleanUpInRegionIfAny(loc, builder, rhsYieldOp.getCleanup(), rhsContext);
Fortran::lower::genCleanUpInRegionIfAny(
loc, builder, rhsYieldOp.getCleanup(), rhsContext);
// Lower LHS in its own region.
builder.createBlock(&regionAssignOp.getLhsRegion());
Fortran::lower::StatementContext lhsContext;
mlir::Value lhsYield = nullptr;
if (!lhsHasVectorSubscripts) {
hlfir::Entity lhs = evaluateLhs(lhsContext);
auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
lhsContext);
Fortran::lower::genCleanUpInRegionIfAny(
loc, builder, lhsYieldOp.getCleanup(), lhsContext);
lhsYield = lhs;
} else {
hlfir::ElementalAddrOp elementalAddr =
Fortran::lower::convertVectorSubscriptedExprToElementalAddr(
loc, *this, assign.lhs, localSymbols, lhsContext);
genCleanUpInRegionIfAny(loc, builder, elementalAddr.getCleanup(),
lhsContext);
Fortran::lower::genCleanUpInRegionIfAny(
loc, builder, elementalAddr.getCleanup(), lhsContext);
lhsYield = elementalAddr.getYieldOp().getEntity();
}
assert(lhsYield && "must have been set");
Expand Down Expand Up @@ -4289,7 +4274,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
loc, *this, *maskExpr, localSymbols, maskContext);
mask = hlfir::loadTrivialScalar(loc, *builder, mask);
auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
Fortran::lower::genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(),
maskContext);
}
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
Expand Down Expand Up @@ -5545,3 +5531,18 @@ Fortran::lower::LoweringBridge::LoweringBridge(
fir::support::setMLIRDataLayout(*module.get(),
targetMachine.createDataLayout());
}

void Fortran::lower::genCleanUpInRegionIfAny(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Region &region,
Fortran::lower::StatementContext &context) {
if (!context.hasCode())
return;
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
if (region.empty())
builder.createBlock(&region);
else
builder.setInsertionPointToEnd(&region.front());
context.finalizeAndPop();
hlfir::YieldOp::ensureTerminator(region, builder, loc);
builder.restoreInsertionPoint(insertPt);
}
38 changes: 38 additions & 0 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2682,10 +2682,48 @@ bool Fortran::lower::isIntrinsicModuleProcRef(
return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC);
}

static bool isInWhereMaskedExpression(fir::FirOpBuilder &builder) {
// The MASK of the outer WHERE is not masked itself.
mlir::Operation *op = builder.getRegion().getParentOp();
return op && op->getParentOfType<hlfir::WhereOp>();
}

std::optional<hlfir::EntityWithAttributes> Fortran::lower::convertCallToHLFIR(
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
const evaluate::ProcedureRef &procRef, std::optional<mlir::Type> resultType,
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
auto &builder = converter.getFirOpBuilder();
if (resultType && !procRef.IsElemental() &&
isInWhereMaskedExpression(builder) &&
!builder.getRegion().getParentOfType<hlfir::ExactlyOnceOp>()) {
// Non elemental calls inside a where-assignment-stmt must be executed
// exactly once without mask control. Lower them in a special region so that
// this can be enforced whenscheduling forall/where expression evaluations.
Fortran::lower::StatementContext localStmtCtx;
mlir::Type bogusType = builder.getIndexType();
auto exactlyOnce = builder.create<hlfir::ExactlyOnceOp>(loc, bogusType);
mlir::Block *block = builder.createBlock(&exactlyOnce.getBody());
builder.setInsertionPointToStart(block);
CallContext callContext(procRef, resultType, loc, converter, symMap,
localStmtCtx);
std::optional<hlfir::EntityWithAttributes> res =
genProcedureRef(callContext);
assert(res.has_value() && "must be a function");
auto yield = builder.create<hlfir::YieldOp>(loc, *res);
Fortran::lower::genCleanUpInRegionIfAny(loc, builder, yield.getCleanup(),
localStmtCtx);
builder.setInsertionPointAfter(exactlyOnce);
exactlyOnce->getResult(0).setType(res->getType());
if (hlfir::isFortranValue(exactlyOnce.getResult()))
return hlfir::EntityWithAttributes{exactlyOnce.getResult()};
// Create hlfir.declare for the result to satisfy
// hlfir::EntityWithAttributes requirements.
auto [exv, cleanup] = hlfir::translateToExtendedValue(
loc, builder, hlfir::Entity{exactlyOnce});
assert(!cleanup && "resut is a variable");
return hlfir::genDeclare(loc, builder, exv, ".func.pointer.result",
fir::FortranVariableFlagsAttr{});
}
CallContext callContext(procRef, resultType, loc, converter, symMap, stmtCtx);
return genProcedureRef(callContext);
}
Expand Down
Loading
Loading