Skip to content

Commit 7883900

Browse files
committed
[flang] Lower type-bound procedure call needing dynamic dispatch to fir.dispatch
Lower call with polymorphic entities to fir.dispatch operation. This patch only focus one lowering with simple scalar polymorphic entities. A follow-up patch will deal with allocatble, pointer and array of polymorphic entities as they require box manipulation for the passed-object. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D135649
1 parent b1d7a95 commit 7883900

File tree

9 files changed

+279
-18
lines changed

9 files changed

+279
-18
lines changed

flang/include/flang/Lower/CallInterface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,14 @@ class CallerInterface : public CallInterface<CallerInterface> {
284284
/// procedure.
285285
bool isIndirectCall() const;
286286

287+
/// Returns true if this is a call of a type-bound procedure with a
288+
/// polymorphic entity.
289+
bool requireDispatchCall() const;
290+
291+
/// Get the passed-object argument index. nullopt if there is no passed-object
292+
/// index.
293+
std::optional<unsigned> getPassArgIndex() const;
294+
287295
/// Return the procedure symbol if this is a call to a user defined
288296
/// procedure.
289297
const Fortran::semantics::Symbol *getProcedureSymbol() const;
@@ -372,6 +380,10 @@ class CalleeInterface : public CallInterface<CalleeInterface> {
372380
/// called through pointers or not.
373381
bool isIndirectCall() const { return false; }
374382

383+
/// On the callee side it does not matter whether the procedure is called
384+
/// through dynamic dispatch or not.
385+
bool requireDispatchCall() const { return false; };
386+
375387
/// Return the procedure symbol if this is a call to a user defined
376388
/// procedure.
377389
const Fortran::semantics::Symbol *getProcedureSymbol() const;

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ inline unsigned getRankOfShapeType(mlir::Type t) {
203203
}
204204

205205
/// Get the memory reference type of the data pointer from the box type,
206-
inline mlir::Type boxMemRefType(fir::BoxType t) {
206+
inline mlir::Type boxMemRefType(fir::BaseBoxType t) {
207207
auto eleTy = t.getEleTy();
208208
if (!eleTy.isa<fir::PointerType, fir::HeapType>())
209209
eleTy = fir::ReferenceType::get(t);

flang/lib/Lower/CallInterface.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@ bool Fortran::lower::CallerInterface::isIndirectCall() const {
8888
return false;
8989
}
9090

91+
bool Fortran::lower::CallerInterface::requireDispatchCall() const {
92+
// calls with NOPASS attribute still have their component so check if it is
93+
// polymorphic.
94+
if (const Fortran::evaluate::Component *component =
95+
procRef.proc().GetComponent()) {
96+
if (Fortran::semantics::IsPolymorphic(component->GetFirstSymbol()))
97+
return true;
98+
}
99+
// calls with PASS attribute have the passed-object already set in its
100+
// arguments. Just check if their is one.
101+
std::optional<unsigned> passArg = getPassArgIndex();
102+
if (passArg)
103+
return true;
104+
return false;
105+
}
106+
107+
std::optional<unsigned>
108+
Fortran::lower::CallerInterface::getPassArgIndex() const {
109+
unsigned passArgIdx = 0;
110+
std::optional<unsigned> passArg = std::nullopt;
111+
for (const auto &arg : getCallDescription().arguments()) {
112+
if (arg && arg->isPassedObject()) {
113+
passArg = passArgIdx;
114+
break;
115+
}
116+
++passArgIdx;
117+
}
118+
return passArg;
119+
}
120+
91121
const Fortran::semantics::Symbol *
92122
Fortran::lower::CallerInterface::getIfIndirectCallSymbol() const {
93123
if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol())

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,8 +1993,10 @@ class ScalarExprLowering {
19931993
}
19941994

19951995
mlir::Value base = fir::getBase(array);
1996-
auto seqTy =
1997-
fir::dyn_cast_ptrOrBoxEleTy(base.getType()).cast<fir::SequenceType>();
1996+
mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(base.getType());
1997+
if (auto classTy = eleTy.dyn_cast<fir::ClassType>())
1998+
eleTy = classTy.getEleTy();
1999+
auto seqTy = eleTy.cast<fir::SequenceType>();
19982000
assert(args.size() == seqTy.getDimension());
19992001
mlir::Type ty = builder.getRefType(seqTy.getEleTy());
20002002
auto addr = builder.create<fir::CoordinateOp>(loc, ty, base, args);
@@ -2727,11 +2729,47 @@ class ScalarExprLowering {
27272729
if (addHostAssociations)
27282730
operands.push_back(converter.hostAssocTupleValue());
27292731

2730-
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
2731-
funcSymbolAttr, operands);
2732+
mlir::Value callResult;
2733+
unsigned callNumResults;
2734+
if (caller.requireDispatchCall()) {
2735+
// Procedure call requiring a dynamic dispatch. Call is created with
2736+
// fir.dispatch.
2737+
2738+
// Get the raw procedure name. The procedure name is not mangled in the
2739+
// binding table.
2740+
const auto &ultimateSymbol =
2741+
caller.getCallDescription().proc().GetSymbol()->GetUltimate();
2742+
auto procName = toStringRef(ultimateSymbol.name());
2743+
2744+
fir::DispatchOp dispatch;
2745+
if (std::optional<unsigned> passArg = caller.getPassArgIndex()) {
2746+
// PASS, PASS(arg-name)
2747+
dispatch = builder.create<fir::DispatchOp>(
2748+
loc, funcType.getResults(), procName, operands[*passArg], operands,
2749+
builder.getI32IntegerAttr(*passArg));
2750+
} else {
2751+
// NOPASS
2752+
const Fortran::evaluate::Component *component =
2753+
caller.getCallDescription().proc().GetComponent();
2754+
assert(component && "expect component for type-bound procedure call.");
2755+
fir::ExtendedValue pass =
2756+
symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue();
2757+
dispatch = builder.create<fir::DispatchOp>(loc, funcType.getResults(),
2758+
procName, fir::getBase(pass),
2759+
operands, nullptr);
2760+
}
2761+
callResult = dispatch.getResult(0);
2762+
callNumResults = dispatch.getNumResults();
2763+
} else {
2764+
// Standard procedure call with fir.call.
2765+
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
2766+
funcSymbolAttr, operands);
2767+
callResult = call.getResult(0);
2768+
callNumResults = call.getNumResults();
2769+
}
27322770

27332771
if (caller.mustSaveResult())
2734-
builder.create<fir::SaveResultOp>(loc, call.getResult(0),
2772+
builder.create<fir::SaveResultOp>(loc, callResult,
27352773
fir::getBase(allocatedResult.value()),
27362774
arrayResultShape, resultLengths);
27372775

@@ -2754,7 +2792,7 @@ class ScalarExprLowering {
27542792
return mlir::Value{}; // subroutine call
27552793
// For now, Fortran return values are implemented with a single MLIR
27562794
// function return value.
2757-
assert(call.getNumResults() == 1 &&
2795+
assert(callNumResults == 1 &&
27582796
"Expected exactly one result in FUNCTION call");
27592797

27602798
// Call a BIND(C) function that return a char.
@@ -2764,10 +2802,10 @@ class ScalarExprLowering {
27642802
funcType.getResults()[0].dyn_cast<fir::CharacterType>();
27652803
mlir::Value len = builder.createIntegerConstant(
27662804
loc, builder.getCharacterLengthType(), charTy.getLen());
2767-
return fir::CharBoxValue{call.getResult(0), len};
2805+
return fir::CharBoxValue{callResult, len};
27682806
}
27692807

2770-
return call.getResult(0);
2808+
return callResult;
27712809
}
27722810

27732811
/// Like genExtAddr, but ensure the address returned is a temporary even if \p
@@ -6012,7 +6050,7 @@ class ArrayExprLowering {
60126050
}
60136051

60146052
static mlir::Type unwrapBoxEleTy(mlir::Type ty) {
6015-
if (auto boxTy = ty.dyn_cast<fir::BoxType>())
6053+
if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>())
60166054
return fir::unwrapRefType(boxTy.getEleTy());
60176055
return ty;
60186056
}
@@ -7150,7 +7188,7 @@ class ArrayExprLowering {
71507188
// Need an intermediate dereference if the boxed value
71517189
// appears in the middle of the component path or if it is
71527190
// on the right and this is not a pointer assignment.
7153-
if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
7191+
if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
71547192
auto currentFunc = components.getExtendCoorRef();
71557193
auto loc = getLoc();
71567194
auto *bldr = &converter.getFirOpBuilder();
@@ -7161,7 +7199,7 @@ class ArrayExprLowering {
71617199
deref = true;
71627200
}
71637201
}
7164-
} else if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
7202+
} else if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
71657203
ty = fir::unwrapRefType(boxTy.getEleTy());
71667204
auto recTy = ty.cast<fir::RecordType>();
71677205
ty = recTy.getType(name);
@@ -7247,7 +7285,7 @@ class ArrayExprLowering {
72477285
// assignment, then insert the dereference of the box before any
72487286
// conversion and store.
72497287
if (!isPointerAssignment()) {
7250-
if (auto boxTy = eleTy.dyn_cast<fir::BoxType>()) {
7288+
if (auto boxTy = eleTy.dyn_cast<fir::BaseBoxType>()) {
72517289
eleTy = fir::boxMemRefType(boxTy);
72527290
addr = builder.create<fir::BoxAddrOp>(loc, eleTy, addr);
72537291
eleTy = fir::unwrapRefType(eleTy);

flang/lib/Lower/Mangler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ Fortran::lower::mangle::mangleName(const Fortran::semantics::Symbol &symbol,
155155
llvm::report_fatal_error(
156156
"only derived type instances can be mangled");
157157
},
158+
[&](const Fortran::semantics::ProcBindingDetails &procBinding)
159+
-> std::string {
160+
return mangleName(procBinding.symbol(), keepExternalInScope);
161+
},
158162
[](const auto &) -> std::string { TODO_NOLOC("symbol mangling"); },
159163
},
160164
ultimateSymbol.details());

flang/lib/Optimizer/Builder/BoxValue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ bool fir::MutableBoxValue::verify() const {
204204
/// Debug verifier for BoxValue ctor. There is no guarantee this will
205205
/// always be called.
206206
bool fir::BoxValue::verify() const {
207-
if (!addr.getType().isa<fir::BoxType>())
207+
if (!addr.getType().isa<fir::BaseBoxType>())
208208
return false;
209209
if (!lbounds.empty() && lbounds.size() != rank())
210210
return false;

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ mlir::Value fir::FirOpBuilder::createSlice(mlir::Location loc,
460460
mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc,
461461
const fir::ExtendedValue &exv) {
462462
mlir::Value itemAddr = fir::getBase(exv);
463-
if (itemAddr.getType().isa<fir::BoxType>())
463+
if (itemAddr.getType().isa<fir::BaseBoxType>())
464464
return itemAddr;
465465
auto elementType = fir::dyn_cast_ptrEleTy(itemAddr.getType());
466466
if (!elementType) {
@@ -741,7 +741,7 @@ static llvm::SmallVector<mlir::Value> getFromBox(mlir::Location loc,
741741
fir::FirOpBuilder &builder,
742742
mlir::Type valTy,
743743
mlir::Value boxVal) {
744-
if (auto boxTy = valTy.dyn_cast<fir::BoxType>()) {
744+
if (auto boxTy = valTy.dyn_cast<fir::BaseBoxType>()) {
745745
auto eleTy = fir::unwrapAllRefAndSeqType(boxTy.getEleTy());
746746
if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
747747
if (recTy.getNumLenParams() > 0) {
@@ -795,7 +795,7 @@ llvm::SmallVector<mlir::Value>
795795
fir::factory::getTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
796796
fir::ArrayLoadOp load) {
797797
mlir::Type memTy = load.getMemref().getType();
798-
if (auto boxTy = memTy.dyn_cast<fir::BoxType>())
798+
if (auto boxTy = memTy.dyn_cast<fir::BaseBoxType>())
799799
return getFromBox(loc, builder, boxTy, load.getMemref());
800800
return load.getTypeparams();
801801
}

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,8 @@ mlir::LogicalResult fir::ConvertOp::verify() {
917917
(inType.isa<fir::BoxType>() && outType.isa<fir::BoxType>()) ||
918918
(inType.isa<fir::BoxProcType>() && outType.isa<fir::BoxProcType>()) ||
919919
(fir::isa_complex(inType) && fir::isa_complex(outType)) ||
920-
(fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)))
920+
(fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
921+
(fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)))
921922
return mlir::success();
922923
return emitOpError("invalid type conversion");
923924
}

0 commit comments

Comments
 (0)