Skip to content

Commit 6e85b88

Browse files
committed
[flang] Add fir.dispatch code generation
fir.dispatch code generation uses the binding table stored in the type descriptor. There is no runtime call involved. The binding table is always build from the parent type so the index of a specific binding is the same in the parent derived-type or in the extended type. Follow-up patches will deal cases not present here such as allocatable polymorphic entities or pointers. Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D136189
1 parent 2e73129 commit 6e85b88

File tree

7 files changed

+363
-35
lines changed

7 files changed

+363
-35
lines changed

flang/include/flang/Optimizer/CodeGen/CGOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> {
5858
Variadic<AnyIntegerType>:$substr,
5959
Variadic<AnyIntegerType>:$lenParams
6060
);
61-
let results = (outs fir_BoxType);
61+
let results = (outs BoxOrClassType);
6262

6363
let assemblyFormat = [{
6464
$memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?
@@ -107,14 +107,14 @@ def fircg_XReboxOp : fircg_Op<"ext_rebox", [AttrSizedOperandSegments]> {
107107
}];
108108

109109
let arguments = (ins
110-
fir_BoxType:$box,
110+
BoxOrClassType:$box,
111111
Variadic<AnyIntegerType>:$shape,
112112
Variadic<AnyIntegerType>:$shift,
113113
Variadic<AnyIntegerType>:$slice,
114114
Variadic<AnyCoordinateType>:$subcomponent,
115115
Variadic<AnyIntegerType>:$substr
116116
);
117-
let results = (outs fir_BoxType);
117+
let results = (outs BoxOrClassType);
118118

119119
let assemblyFormat = [{
120120
$box (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2745,18 +2745,18 @@ class ScalarExprLowering {
27452745
if (std::optional<unsigned> passArg = caller.getPassArgIndex()) {
27462746
// PASS, PASS(arg-name)
27472747
dispatch = builder.create<fir::DispatchOp>(
2748-
loc, funcType.getResults(), procName, operands[*passArg], operands,
2749-
builder.getI32IntegerAttr(*passArg));
2748+
loc, funcType.getResults(), builder.getStringAttr(procName),
2749+
operands[*passArg], operands, builder.getI32IntegerAttr(*passArg));
27502750
} else {
27512751
// NOPASS
27522752
const Fortran::evaluate::Component *component =
27532753
caller.getCallDescription().proc().GetComponent();
27542754
assert(component && "expect component for type-bound procedure call.");
27552755
fir::ExtendedValue pass =
27562756
symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue();
2757-
dispatch = builder.create<fir::DispatchOp>(loc, funcType.getResults(),
2758-
procName, fir::getBase(pass),
2759-
operands, nullptr);
2757+
dispatch = builder.create<fir::DispatchOp>(
2758+
loc, funcType.getResults(), builder.getStringAttr(procName),
2759+
fir::getBase(pass), operands, nullptr);
27602760
}
27612761
callResult = dispatch.getResult(0);
27622762
callNumResults = dispatch.getNumResults();

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 124 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,123 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
893893
mlir::LogicalResult
894894
matchAndRewrite(fir::DispatchOp dispatch, OpAdaptor adaptor,
895895
mlir::ConversionPatternRewriter &rewriter) const override {
896-
TODO(dispatch.getLoc(), "fir.dispatch codegen");
897-
return mlir::failure();
896+
mlir::Location loc = dispatch.getLoc();
897+
898+
if (bindingTables.empty())
899+
return emitError(loc) << "no binding tables found";
900+
901+
if (dispatch.getObject()
902+
.getType()
903+
.getEleTy()
904+
.isa<fir::HeapType, fir::PointerType>())
905+
TODO(loc,
906+
"fir.dispatch with allocatable or pointer polymorphic entities");
907+
908+
// Get derived type information.
909+
auto declaredType = dispatch.getObject().getType().getEleTy();
910+
assert(declaredType.isa<fir::RecordType>() && "expecting fir.type");
911+
auto recordType = declaredType.dyn_cast<fir::RecordType>();
912+
std::string typeDescName =
913+
fir::NameUniquer::getTypeDescriptorName(recordType.getName());
914+
std::string typeDescBindingTableName =
915+
fir::NameUniquer::getTypeDescriptorBindingTableName(
916+
recordType.getName());
917+
918+
// Lookup for the binding table.
919+
auto bindingsIter = bindingTables.find(typeDescBindingTableName);
920+
if (bindingsIter == bindingTables.end())
921+
return emitError(loc)
922+
<< "cannot find binding table for " << typeDescBindingTableName;
923+
924+
// Lookup for the binding.
925+
const BindingTable &bindingTable = bindingsIter->second;
926+
auto bindingIter = bindingTable.find(dispatch.getMethod());
927+
if (bindingIter == bindingTable.end())
928+
return emitError(loc)
929+
<< "cannot find binding for " << dispatch.getMethod();
930+
unsigned bindingIdx = bindingIter->second;
931+
932+
mlir::Value passedObject = dispatch.getObject();
933+
934+
auto module = dispatch.getOperation()->getParentOfType<mlir::ModuleOp>();
935+
mlir::Type typeDescTy;
936+
if (auto global = module.lookupSymbol<fir::GlobalOp>(typeDescName)) {
937+
typeDescTy = convertType(global.getType());
938+
} else if (auto global =
939+
module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
940+
// The global may have already been translated to LLVM.
941+
typeDescTy = global.getType();
942+
}
943+
944+
auto isArray = fir::dyn_cast_ptrOrBoxEleTy(passedObject.getType())
945+
.template isa<fir::SequenceType>();
946+
unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
947+
948+
auto descPtr = adaptor.getOperands()[0]
949+
.getType()
950+
.dyn_cast<mlir::LLVM::LLVMPointerType>();
951+
952+
// Load the descriptor.
953+
auto desc = rewriter.create<mlir::LLVM::LoadOp>(
954+
loc, descPtr.getElementType(), adaptor.getOperands()[0]);
955+
956+
// Load the type descriptor.
957+
auto typeDescPtr =
958+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, desc, typeDescFieldId);
959+
auto typeDesc =
960+
rewriter.create<mlir::LLVM::LoadOp>(loc, typeDescTy, typeDescPtr);
961+
962+
// Load the bindings descriptor.
963+
auto typeDescStructTy = typeDescTy.dyn_cast<mlir::LLVM::LLVMStructType>();
964+
auto bindingDescType =
965+
typeDescStructTy.getBody()[0].dyn_cast<mlir::LLVM::LLVMStructType>();
966+
auto bindingDesc =
967+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, typeDesc, 0);
968+
969+
// Load the correct binding.
970+
auto bindingType =
971+
bindingDescType.getBody()[0].dyn_cast<mlir::LLVM::LLVMPointerType>();
972+
auto baseBindingPtr = rewriter.create<mlir::LLVM::ExtractValueOp>(
973+
loc, bindingDesc, kAddrPosInBox);
974+
auto bindingPtr = rewriter.create<mlir::LLVM::GEPOp>(
975+
loc, bindingType, baseBindingPtr,
976+
llvm::ArrayRef<mlir::LLVM::GEPArg>{static_cast<int32_t>(bindingIdx)});
977+
auto binding = rewriter.create<mlir::LLVM::LoadOp>(
978+
loc, bindingType.getElementType(), bindingPtr);
979+
980+
// Get the function type.
981+
llvm::SmallVector<mlir::Type> argTypes;
982+
for (mlir::Value operand : adaptor.getOperands().drop_front())
983+
argTypes.push_back(operand.getType());
984+
mlir::Type resultType;
985+
if (dispatch.getResults().empty())
986+
resultType = mlir::LLVM::LLVMVoidType::get(dispatch.getContext());
987+
else
988+
resultType = convertType(dispatch.getResults()[0].getType());
989+
auto fctType = mlir::LLVM::LLVMFunctionType::get(resultType, argTypes,
990+
/*isVarArg=*/false);
991+
992+
// Get the function pointer.
993+
auto builtinFuncPtr =
994+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, binding, 0);
995+
auto funcAddr =
996+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, builtinFuncPtr, 0);
997+
auto funcPtr = rewriter.create<mlir::LLVM::IntToPtrOp>(
998+
loc, mlir::LLVM::LLVMPointerType::get(fctType), funcAddr);
999+
1000+
// Indirect calls are done with the function pointer as the first operand.
1001+
llvm::SmallVector<mlir::Value> args;
1002+
args.push_back(funcPtr);
1003+
for (mlir::Value operand : adaptor.getOperands().drop_front())
1004+
args.push_back(operand);
1005+
auto callOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1006+
dispatch,
1007+
dispatch.getResults().empty() ? mlir::TypeRange{}
1008+
: fctType.getReturnType(),
1009+
"", args);
1010+
callOp.removeCalleeAttr(); // Indirect calls do not have callee attr.
1011+
1012+
return mlir::success();
8981013
}
8991014
};
9001015

@@ -1127,7 +1242,7 @@ template <typename OP>
11271242
struct EmboxCommonConversion : public FIROpConversion<OP> {
11281243
using FIROpConversion<OP>::FIROpConversion;
11291244

1130-
static int getCFIAttr(fir::BoxType boxTy) {
1245+
static int getCFIAttr(fir::BaseBoxType boxTy) {
11311246
auto eleTy = boxTy.getEleTy();
11321247
if (eleTy.isa<fir::PointerType>())
11331248
return CFI_attribute_pointer;
@@ -1136,15 +1251,15 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
11361251
return CFI_attribute_other;
11371252
}
11381253

1139-
static fir::RecordType unwrapIfDerived(fir::BoxType boxTy) {
1254+
static fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) {
11401255
return fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(boxTy))
11411256
.template dyn_cast<fir::RecordType>();
11421257
}
1143-
static bool isDerivedTypeWithLenParams(fir::BoxType boxTy) {
1258+
static bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) {
11441259
auto recTy = unwrapIfDerived(boxTy);
11451260
return recTy && recTy.getNumLenParams() > 0;
11461261
}
1147-
static bool isDerivedType(fir::BoxType boxTy) {
1262+
static bool isDerivedType(fir::BaseBoxType boxTy) {
11481263
return static_cast<bool>(unwrapIfDerived(boxTy));
11491264
}
11501265

@@ -1342,11 +1457,11 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
13421457
}
13431458

13441459
template <typename BOX>
1345-
std::tuple<fir::BoxType, mlir::Value, mlir::Value>
1460+
std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value>
13461461
consDescriptorPrefix(BOX box, mlir::ConversionPatternRewriter &rewriter,
13471462
unsigned rank, mlir::ValueRange lenParams) const {
13481463
auto loc = box.getLoc();
1349-
auto boxTy = box.getType().template dyn_cast<fir::BoxType>();
1464+
auto boxTy = box.getType().template dyn_cast<fir::BaseBoxType>();
13501465
auto convTy = this->lowerTy().convertBoxType(boxTy, rank);
13511466
auto llvmBoxPtrTy = convTy.template cast<mlir::LLVM::LLVMPointerType>();
13521467
auto llvmBoxTy = llvmBoxPtrTy.getElementType();
@@ -3367,7 +3482,7 @@ class FIRToLLVMLowering
33673482
// and binding index for later use by the fir.dispatch conversion pattern.
33683483
BindingTables bindingTables;
33693484
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
3370-
if (globalOp.getSymName().contains(".v.")) {
3485+
if (globalOp.getSymName().contains(bindingTableSeparator)) {
33713486
unsigned bindingIdx = 0;
33723487
BindingTable bindings;
33733488
for (auto addrOp : globalOp.getRegion().getOps<fir::AddrOfOp>()) {

flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,8 @@ class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> {
277277
target.addIllegalOp<fir::ArrayCoorOp>();
278278
target.addIllegalOp<fir::ReboxOp>();
279279
target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
280-
if (embox.getType().isa<fir::ClassType>())
281-
TODO(embox.getLoc(), "fir.class type CodeGenRewrite");
282280
return !(embox.getShape() || embox.getType()
283-
.cast<fir::BoxType>()
281+
.cast<fir::BaseBoxType>()
284282
.getEleTy()
285283
.isa<fir::SequenceType>());
286284
});

flang/lib/Optimizer/CodeGen/TypeConverter.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,8 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
6464
// procedure pointer feature is implemented.
6565
return llvm::None;
6666
});
67-
addConversion([&](fir::ClassType classTy) {
68-
TODO_NOLOC("fir.class type conversion");
69-
return llvm::None;
70-
});
67+
addConversion(
68+
[&](fir::ClassType classTy) { return convertBoxType(classTy); });
7169
addConversion(
7270
[&](fir::CharacterType charTy) { return convertCharType(charTy); });
7371
addConversion(
@@ -203,7 +201,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
203201

204202
// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
205203
// the addendum defined in descriptor.h.
206-
mlir::Type convertBoxType(BoxType box, int rank = unknownRank()) {
204+
mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()) {
207205
// (base_addr*, elem_len, version, rank, type, attribute, f18Addendum, [dim]
208206
llvm::SmallVector<mlir::Type> dataDescFields;
209207
mlir::Type ele = box.getEleTy();

flang/test/Fir/Todo/dispatch.fir

Lines changed: 0 additions & 10 deletions
This file was deleted.

0 commit comments

Comments
 (0)