Skip to content

Commit cd7e653

Browse files
authored
[flang] optimize array function calls using hlfir.eval_in_mem (#118070)
This patch encapsulate array function call lowering into hlfir.eval_in_mem and allows directly evaluating the call into the LHS when possible. The conditions are: LHS is contiguous, not accessed inside the function, it is not a whole allocatable, and the function results needs not to be finalized. All these conditions are tested in the previous hlfir.eval_in_mem optimization (#118069) that is leveraging the extension of getModRef to handle function calls(#117164). This yields a 25% speed-up on polyhedron channel2 benchmark (from 1min to 45s measured on an X86-64 Zen 2).
1 parent a871124 commit cd7e653

File tree

12 files changed

+258
-74
lines changed

12 files changed

+258
-74
lines changed

flang/include/flang/Lower/ConvertCall.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424

2525
namespace Fortran::lower {
2626

27+
/// Data structure packaging the SSA value(s) produced for the result of lowered
28+
/// function calls.
29+
using LoweredResult =
30+
std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes>;
31+
2732
/// Given a call site for which the arguments were already lowered, generate
2833
/// the call and return the result. This function deals with explicit result
2934
/// allocation and lowering if needed. It also deals with passing the host
@@ -32,7 +37,7 @@ namespace Fortran::lower {
3237
/// It is only used for HLFIR.
3338
/// The returned boolean indicates if finalization has been emitted in
3439
/// \p stmtCtx for the result.
35-
std::pair<fir::ExtendedValue, bool> genCallOpAndResult(
40+
std::pair<LoweredResult, bool> genCallOpAndResult(
3641
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
3742
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3843
Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,

flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ inline mlir::Type getFortranElementOrSequenceType(mlir::Type type) {
6161
return type;
6262
}
6363

64+
/// Build the hlfir.expr type for the value held in a variable of type \p
65+
/// variableType.
66+
mlir::Type getExprType(mlir::Type variableType);
67+
6468
/// Is this a fir.box or fir.class address type?
6569
inline bool isBoxAddressType(mlir::Type type) {
6670
type = fir::dyn_cast_ptrEleTy(type);

flang/lib/Lower/ConvertCall.cpp

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ static void remapActualToDummyDescriptors(
284284
}
285285
}
286286

287-
std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
287+
std::pair<Fortran::lower::LoweredResult, bool>
288+
Fortran::lower::genCallOpAndResult(
288289
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
289290
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
290291
Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
@@ -326,13 +327,20 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
326327
}
327328
}
328329

330+
const bool isExprCall =
331+
converter.getLoweringOptions().getLowerToHighLevelFIR() &&
332+
callSiteType.getNumResults() == 1 &&
333+
llvm::isa<fir::SequenceType>(callSiteType.getResult(0));
334+
329335
mlir::IndexType idxTy = builder.getIndexType();
330336
auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value {
331337
mlir::Value convertExpr = builder.createConvert(
332338
loc, idxTy, fir::getBase(converter.genExprValue(expr, stmtCtx)));
333339
return fir::factory::genMaxWithZero(builder, loc, convertExpr);
334340
};
335341
llvm::SmallVector<mlir::Value> resultLengths;
342+
mlir::Value arrayResultShape;
343+
hlfir::EvaluateInMemoryOp evaluateInMemory;
336344
auto allocatedResult = [&]() -> std::optional<fir::ExtendedValue> {
337345
llvm::SmallVector<mlir::Value> extents;
338346
llvm::SmallVector<mlir::Value> lengths;
@@ -366,6 +374,18 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
366374
resultLengths = lengths;
367375
}
368376

377+
if (!extents.empty())
378+
arrayResultShape = builder.genShape(loc, extents);
379+
380+
if (isExprCall) {
381+
mlir::Type exprType = hlfir::getExprType(type);
382+
evaluateInMemory = builder.create<hlfir::EvaluateInMemoryOp>(
383+
loc, exprType, arrayResultShape, resultLengths);
384+
builder.setInsertionPointToStart(&evaluateInMemory.getBody().front());
385+
return toExtendedValue(loc, evaluateInMemory.getMemory(), extents,
386+
lengths);
387+
}
388+
369389
if ((!extents.empty() || !lengths.empty()) && !isElemental) {
370390
// Note: in the elemental context, the alloca ownership inside the
371391
// elemental region is implicit, and later pass in lowering (stack
@@ -384,8 +404,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
384404
if (mustPopSymMap)
385405
symMap.popScope();
386406

387-
// Place allocated result or prepare the fir.save_result arguments.
388-
mlir::Value arrayResultShape;
407+
// Place allocated result
389408
if (allocatedResult) {
390409
if (std::optional<Fortran::lower::CallInterface<
391410
Fortran::lower::CallerInterface>::PassedEntity>
@@ -399,16 +418,6 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
399418
else
400419
fir::emitFatalError(
401420
loc, "only expect character scalar result to be passed by ref");
402-
} else {
403-
assert(caller.mustSaveResult());
404-
arrayResultShape = allocatedResult->match(
405-
[&](const fir::CharArrayBoxValue &) {
406-
return builder.createShape(loc, *allocatedResult);
407-
},
408-
[&](const fir::ArrayBoxValue &) {
409-
return builder.createShape(loc, *allocatedResult);
410-
},
411-
[&](const auto &) { return mlir::Value{}; });
412421
}
413422
}
414423

@@ -642,13 +651,39 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
642651
callResult = call.getResult(0);
643652
}
644653

654+
std::optional<Fortran::evaluate::DynamicType> retTy =
655+
caller.getCallDescription().proc().GetType();
656+
// With HLFIR lowering, isElemental must be set to true
657+
// if we are producing an elemental call. In this case,
658+
// the elemental results must not be destroyed, instead,
659+
// the resulting array result will be finalized/destroyed
660+
// as needed by hlfir.destroy.
661+
const bool mustFinalizeResult =
662+
!isElemental && callSiteType.getNumResults() > 0 &&
663+
!fir::isPointerType(callSiteType.getResult(0)) && retTy.has_value() &&
664+
(retTy->category() == Fortran::common::TypeCategory::Derived ||
665+
retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic());
666+
645667
if (caller.mustSaveResult()) {
646668
assert(allocatedResult.has_value());
647669
builder.create<fir::SaveResultOp>(loc, callResult,
648670
fir::getBase(*allocatedResult),
649671
arrayResultShape, resultLengths);
650672
}
651673

674+
if (evaluateInMemory) {
675+
builder.setInsertionPointAfter(evaluateInMemory);
676+
mlir::Value expr = evaluateInMemory.getResult();
677+
fir::FirOpBuilder *bldr = &converter.getFirOpBuilder();
678+
if (!isElemental)
679+
stmtCtx.attachCleanup([bldr, loc, expr, mustFinalizeResult]() {
680+
bldr->create<hlfir::DestroyOp>(loc, expr,
681+
/*finalize=*/mustFinalizeResult);
682+
});
683+
return {LoweredResult{hlfir::EntityWithAttributes{expr}},
684+
mustFinalizeResult};
685+
}
686+
652687
if (allocatedResult) {
653688
// The result must be optionally destroyed (if it is of a derived type
654689
// that may need finalization or deallocation of the components).
@@ -679,17 +714,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
679714
// derived-type.
680715
// For polymorphic and unlimited polymorphic enities call the runtime
681716
// in any cases.
682-
std::optional<Fortran::evaluate::DynamicType> retTy =
683-
caller.getCallDescription().proc().GetType();
684-
// With HLFIR lowering, isElemental must be set to true
685-
// if we are producing an elemental call. In this case,
686-
// the elemental results must not be destroyed, instead,
687-
// the resulting array result will be finalized/destroyed
688-
// as needed by hlfir.destroy.
689-
if (!isElemental && !fir::isPointerType(funcType.getResults()[0]) &&
690-
retTy &&
691-
(retTy->category() == Fortran::common::TypeCategory::Derived ||
692-
retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic())) {
717+
if (mustFinalizeResult) {
693718
if (retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()) {
694719
auto *bldr = &converter.getFirOpBuilder();
695720
stmtCtx.attachCleanup([bldr, loc, allocatedResult]() {
@@ -715,12 +740,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
715740
}
716741
}
717742
}
718-
return {*allocatedResult, resultIsFinalized};
743+
return {LoweredResult{*allocatedResult}, resultIsFinalized};
719744
}
720745

721746
// subroutine call
722747
if (!resultType)
723-
return {fir::ExtendedValue{mlir::Value{}}, /*resultIsFinalized=*/false};
748+
return {LoweredResult{fir::ExtendedValue{mlir::Value{}}},
749+
/*resultIsFinalized=*/false};
724750

725751
// For now, Fortran return values are implemented with a single MLIR
726752
// function return value.
@@ -734,10 +760,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
734760
mlir::dyn_cast<fir::CharacterType>(funcType.getResults()[0]);
735761
mlir::Value len = builder.createIntegerConstant(
736762
loc, builder.getCharacterLengthType(), charTy.getLen());
737-
return {fir::CharBoxValue{callResult, len}, /*resultIsFinalized=*/false};
763+
return {
764+
LoweredResult{fir::ExtendedValue{fir::CharBoxValue{callResult, len}}},
765+
/*resultIsFinalized=*/false};
738766
}
739767

740-
return {callResult, /*resultIsFinalized=*/false};
768+
return {LoweredResult{fir::ExtendedValue{callResult}},
769+
/*resultIsFinalized=*/false};
741770
}
742771

743772
static hlfir::EntityWithAttributes genStmtFunctionRef(
@@ -1661,19 +1690,25 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
16611690
// Prepare lowered arguments according to the interface
16621691
// and map the lowered values to the dummy
16631692
// arguments.
1664-
auto [result, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
1693+
auto [loweredResult, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
16651694
loc, callContext.converter, callContext.symMap, callContext.stmtCtx,
16661695
caller, callSiteType, callContext.resultType,
16671696
callContext.isElementalProcWithArrayArgs());
1668-
// For procedure pointer function result, just return the call.
1669-
if (callContext.resultType &&
1670-
mlir::isa<fir::BoxProcType>(*callContext.resultType))
1671-
return hlfir::EntityWithAttributes(fir::getBase(result));
16721697

16731698
/// Clean-up associations and copy-in.
16741699
for (auto cleanUp : callCleanUps)
16751700
cleanUp.genCleanUp(loc, builder);
16761701

1702+
if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
1703+
return *entity;
1704+
1705+
auto &result = std::get<fir::ExtendedValue>(loweredResult);
1706+
1707+
// For procedure pointer function result, just return the call.
1708+
if (callContext.resultType &&
1709+
mlir::isa<fir::BoxProcType>(*callContext.resultType))
1710+
return hlfir::EntityWithAttributes(fir::getBase(result));
1711+
16771712
if (!fir::getBase(result))
16781713
return std::nullopt; // subroutine call.
16791714

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2852,10 +2852,11 @@ class ScalarExprLowering {
28522852
}
28532853
}
28542854

2855-
ExtValue result =
2855+
auto loweredResult =
28562856
Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx,
28572857
caller, callSiteType, resultType)
28582858
.first;
2859+
auto &result = std::get<ExtValue>(loweredResult);
28592860

28602861
// Sync pointers and allocatables that may have been modified during the
28612862
// call.
@@ -4881,10 +4882,12 @@ class ArrayExprLowering {
48814882
[&](const auto &) { return fir::getBase(exv); });
48824883
caller.placeInput(argIface, arg);
48834884
}
4884-
return Fortran::lower::genCallOpAndResult(loc, converter, symMap,
4885-
getElementCtx(), caller,
4886-
callSiteType, retTy)
4887-
.first;
4885+
Fortran::lower::LoweredResult res =
4886+
Fortran::lower::genCallOpAndResult(loc, converter, symMap,
4887+
getElementCtx(), caller,
4888+
callSiteType, retTy)
4889+
.first;
4890+
return std::get<ExtValue>(res);
48884891
};
48894892
}
48904893

flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,16 @@ bool hlfir::mayHaveAllocatableComponent(mlir::Type ty) {
215215
return fir::isPolymorphicType(ty) || fir::isUnlimitedPolymorphicType(ty) ||
216216
fir::isRecordWithAllocatableMember(hlfir::getFortranElementType(ty));
217217
}
218+
219+
mlir::Type hlfir::getExprType(mlir::Type variableType) {
220+
hlfir::ExprType::Shape typeShape;
221+
bool isPolymorphic = fir::isPolymorphicType(variableType);
222+
mlir::Type type = getFortranElementOrSequenceType(variableType);
223+
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
224+
assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions");
225+
typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
226+
type = seqType.getEleTy();
227+
}
228+
return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
229+
isPolymorphic);
230+
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,16 +1427,7 @@ llvm::LogicalResult hlfir::EndAssociateOp::verify() {
14271427
void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
14281428
mlir::OperationState &result, mlir::Value var,
14291429
mlir::Value mustFree) {
1430-
hlfir::ExprType::Shape typeShape;
1431-
bool isPolymorphic = fir::isPolymorphicType(var.getType());
1432-
mlir::Type type = getFortranElementOrSequenceType(var.getType());
1433-
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
1434-
typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
1435-
type = seqType.getEleTy();
1436-
}
1437-
1438-
auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type,
1439-
isPolymorphic);
1430+
mlir::Type resultType = hlfir::getExprType(var.getType());
14401431
return build(builder, result, resultType, var, mustFree);
14411432
}
14421433

flang/test/HLFIR/order_assignments/where-scheduling.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ end function f
134134
!CHECK-NEXT: run 1 save : where/mask
135135
!CHECK-NEXT: run 2 evaluate: where/region_assign1
136136
!CHECK-LABEL: ------------ scheduling where in _QPonly_once ------------
137-
!CHECK-NEXT: unknown effect: %{{[0-9]+}} = llvm.intr.stacksave : !llvm.ptr
137+
!CHECK-NEXT: unknown effect: %11 = fir.call @_QPcall_me_only_once() fastmath<contract> : () -> !fir.array<10x!fir.logical<4>>
138138
!CHECK-NEXT: saving eval because write effect prevents re-evaluation
139139
!CHECK-NEXT: run 1 save (w): where/mask
140140
!CHECK-NEXT: run 2 evaluate: where/region_assign1

0 commit comments

Comments
 (0)