@@ -893,8 +893,123 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
893
893
mlir::LogicalResult
894
894
matchAndRewrite (fir::DispatchOp dispatch, OpAdaptor adaptor,
895
895
mlir::ConversionPatternRewriter &rewriter) const override {
896
- TODO (dispatch.getLoc (), " fir.dispatch codegen" );
897
- return mlir::failure ();
896
+ mlir::Location loc = dispatch.getLoc ();
897
+
898
+ if (bindingTables.empty ())
899
+ return emitError (loc) << " no binding tables found" ;
900
+
901
+ if (dispatch.getObject ()
902
+ .getType ()
903
+ .getEleTy ()
904
+ .isa <fir::HeapType, fir::PointerType>())
905
+ TODO (loc,
906
+ " fir.dispatch with allocatable or pointer polymorphic entities" );
907
+
908
+ // Get derived type information.
909
+ auto declaredType = dispatch.getObject ().getType ().getEleTy ();
910
+ assert (declaredType.isa <fir::RecordType>() && " expecting fir.type" );
911
+ auto recordType = declaredType.dyn_cast <fir::RecordType>();
912
+ std::string typeDescName =
913
+ fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
914
+ std::string typeDescBindingTableName =
915
+ fir::NameUniquer::getTypeDescriptorBindingTableName (
916
+ recordType.getName ());
917
+
918
+ // Lookup for the binding table.
919
+ auto bindingsIter = bindingTables.find (typeDescBindingTableName);
920
+ if (bindingsIter == bindingTables.end ())
921
+ return emitError (loc)
922
+ << " cannot find binding table for " << typeDescBindingTableName;
923
+
924
+ // Lookup for the binding.
925
+ const BindingTable &bindingTable = bindingsIter->second ;
926
+ auto bindingIter = bindingTable.find (dispatch.getMethod ());
927
+ if (bindingIter == bindingTable.end ())
928
+ return emitError (loc)
929
+ << " cannot find binding for " << dispatch.getMethod ();
930
+ unsigned bindingIdx = bindingIter->second ;
931
+
932
+ mlir::Value passedObject = dispatch.getObject ();
933
+
934
+ auto module = dispatch.getOperation ()->getParentOfType <mlir::ModuleOp>();
935
+ mlir::Type typeDescTy;
936
+ if (auto global = module .lookupSymbol <fir::GlobalOp>(typeDescName)) {
937
+ typeDescTy = convertType (global.getType ());
938
+ } else if (auto global =
939
+ module .lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
940
+ // The global may have already been translated to LLVM.
941
+ typeDescTy = global.getType ();
942
+ }
943
+
944
+ auto isArray = fir::dyn_cast_ptrOrBoxEleTy (passedObject.getType ())
945
+ .template isa <fir::SequenceType>();
946
+ unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox ;
947
+
948
+ auto descPtr = adaptor.getOperands ()[0 ]
949
+ .getType ()
950
+ .dyn_cast <mlir::LLVM::LLVMPointerType>();
951
+
952
+ // Load the descriptor.
953
+ auto desc = rewriter.create <mlir::LLVM::LoadOp>(
954
+ loc, descPtr.getElementType (), adaptor.getOperands ()[0 ]);
955
+
956
+ // Load the type descriptor.
957
+ auto typeDescPtr =
958
+ rewriter.create <mlir::LLVM::ExtractValueOp>(loc, desc, typeDescFieldId);
959
+ auto typeDesc =
960
+ rewriter.create <mlir::LLVM::LoadOp>(loc, typeDescTy, typeDescPtr);
961
+
962
+ // Load the bindings descriptor.
963
+ auto typeDescStructTy = typeDescTy.dyn_cast <mlir::LLVM::LLVMStructType>();
964
+ auto bindingDescType =
965
+ typeDescStructTy.getBody ()[0 ].dyn_cast <mlir::LLVM::LLVMStructType>();
966
+ auto bindingDesc =
967
+ rewriter.create <mlir::LLVM::ExtractValueOp>(loc, typeDesc, 0 );
968
+
969
+ // Load the correct binding.
970
+ auto bindingType =
971
+ bindingDescType.getBody ()[0 ].dyn_cast <mlir::LLVM::LLVMPointerType>();
972
+ auto baseBindingPtr = rewriter.create <mlir::LLVM::ExtractValueOp>(
973
+ loc, bindingDesc, kAddrPosInBox );
974
+ auto bindingPtr = rewriter.create <mlir::LLVM::GEPOp>(
975
+ loc, bindingType, baseBindingPtr,
976
+ llvm::ArrayRef<mlir::LLVM::GEPArg>{static_cast <int32_t >(bindingIdx)});
977
+ auto binding = rewriter.create <mlir::LLVM::LoadOp>(
978
+ loc, bindingType.getElementType (), bindingPtr);
979
+
980
+ // Get the function type.
981
+ llvm::SmallVector<mlir::Type> argTypes;
982
+ for (mlir::Value operand : adaptor.getOperands ().drop_front ())
983
+ argTypes.push_back (operand.getType ());
984
+ mlir::Type resultType;
985
+ if (dispatch.getResults ().empty ())
986
+ resultType = mlir::LLVM::LLVMVoidType::get (dispatch.getContext ());
987
+ else
988
+ resultType = convertType (dispatch.getResults ()[0 ].getType ());
989
+ auto fctType = mlir::LLVM::LLVMFunctionType::get (resultType, argTypes,
990
+ /* isVarArg=*/ false );
991
+
992
+ // Get the function pointer.
993
+ auto builtinFuncPtr =
994
+ rewriter.create <mlir::LLVM::ExtractValueOp>(loc, binding, 0 );
995
+ auto funcAddr =
996
+ rewriter.create <mlir::LLVM::ExtractValueOp>(loc, builtinFuncPtr, 0 );
997
+ auto funcPtr = rewriter.create <mlir::LLVM::IntToPtrOp>(
998
+ loc, mlir::LLVM::LLVMPointerType::get (fctType), funcAddr);
999
+
1000
+ // Indirect calls are done with the function pointer as the first operand.
1001
+ llvm::SmallVector<mlir::Value> args;
1002
+ args.push_back (funcPtr);
1003
+ for (mlir::Value operand : adaptor.getOperands ().drop_front ())
1004
+ args.push_back (operand);
1005
+ auto callOp = rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1006
+ dispatch,
1007
+ dispatch.getResults ().empty () ? mlir::TypeRange{}
1008
+ : fctType.getReturnType (),
1009
+ " " , args);
1010
+ callOp.removeCalleeAttr (); // Indirect calls do not have callee attr.
1011
+
1012
+ return mlir::success ();
898
1013
}
899
1014
};
900
1015
@@ -1127,7 +1242,7 @@ template <typename OP>
1127
1242
struct EmboxCommonConversion : public FIROpConversion <OP> {
1128
1243
using FIROpConversion<OP>::FIROpConversion;
1129
1244
1130
- static int getCFIAttr (fir::BoxType boxTy) {
1245
+ static int getCFIAttr (fir::BaseBoxType boxTy) {
1131
1246
auto eleTy = boxTy.getEleTy ();
1132
1247
if (eleTy.isa <fir::PointerType>())
1133
1248
return CFI_attribute_pointer;
@@ -1136,15 +1251,15 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1136
1251
return CFI_attribute_other;
1137
1252
}
1138
1253
1139
- static fir::RecordType unwrapIfDerived (fir::BoxType boxTy) {
1254
+ static fir::RecordType unwrapIfDerived (fir::BaseBoxType boxTy) {
1140
1255
return fir::unwrapSequenceType (fir::dyn_cast_ptrOrBoxEleTy (boxTy))
1141
1256
.template dyn_cast <fir::RecordType>();
1142
1257
}
1143
- static bool isDerivedTypeWithLenParams (fir::BoxType boxTy) {
1258
+ static bool isDerivedTypeWithLenParams (fir::BaseBoxType boxTy) {
1144
1259
auto recTy = unwrapIfDerived (boxTy);
1145
1260
return recTy && recTy.getNumLenParams () > 0 ;
1146
1261
}
1147
- static bool isDerivedType (fir::BoxType boxTy) {
1262
+ static bool isDerivedType (fir::BaseBoxType boxTy) {
1148
1263
return static_cast <bool >(unwrapIfDerived (boxTy));
1149
1264
}
1150
1265
@@ -1342,11 +1457,11 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1342
1457
}
1343
1458
1344
1459
template <typename BOX>
1345
- std::tuple<fir::BoxType , mlir::Value, mlir::Value>
1460
+ std::tuple<fir::BaseBoxType , mlir::Value, mlir::Value>
1346
1461
consDescriptorPrefix (BOX box, mlir::ConversionPatternRewriter &rewriter,
1347
1462
unsigned rank, mlir::ValueRange lenParams) const {
1348
1463
auto loc = box.getLoc ();
1349
- auto boxTy = box.getType ().template dyn_cast <fir::BoxType >();
1464
+ auto boxTy = box.getType ().template dyn_cast <fir::BaseBoxType >();
1350
1465
auto convTy = this ->lowerTy ().convertBoxType (boxTy, rank);
1351
1466
auto llvmBoxPtrTy = convTy.template cast <mlir::LLVM::LLVMPointerType>();
1352
1467
auto llvmBoxTy = llvmBoxPtrTy.getElementType ();
@@ -3367,7 +3482,7 @@ class FIRToLLVMLowering
3367
3482
// and binding index for later use by the fir.dispatch conversion pattern.
3368
3483
BindingTables bindingTables;
3369
3484
for (auto globalOp : mod.getOps <fir::GlobalOp>()) {
3370
- if (globalOp.getSymName ().contains (" .v. " )) {
3485
+ if (globalOp.getSymName ().contains (bindingTableSeparator )) {
3371
3486
unsigned bindingIdx = 0 ;
3372
3487
BindingTable bindings;
3373
3488
for (auto addrOp : globalOp.getRegion ().getOps <fir::AddrOfOp>()) {
0 commit comments