@@ -62,14 +62,21 @@ struct FixupTy {
62
62
FixupTy (Codes code, std::size_t index,
63
63
std::function<void (mlir::func::FuncOp)> &&finalizer)
64
64
: code{code}, index{index}, finalizer{finalizer} {}
65
+ FixupTy (Codes code, std::size_t index,
66
+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
67
+ : code{code}, index{index}, gpuFinalizer{finalizer} {}
65
68
FixupTy (Codes code, std::size_t index, std::size_t second,
66
69
std::function<void (mlir::func::FuncOp)> &&finalizer)
67
70
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
71
+ FixupTy (Codes code, std::size_t index, std::size_t second,
72
+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
73
+ : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
68
74
69
75
Codes code;
70
76
std::size_t index;
71
77
std::size_t second{};
72
78
std::optional<std::function<void (mlir::func::FuncOp)>> finalizer{};
79
+ std::optional<std::function<void (mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
73
80
}; // namespace
74
81
75
82
// / Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -719,12 +726,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
719
726
if (targetFeaturesAttr)
720
727
fn->setAttr (" target_features" , targetFeaturesAttr);
721
728
722
- convertSignature (fn);
729
+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp> (fn);
723
730
}
724
731
725
- for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>())
732
+ for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>()) {
726
733
for (auto fn : gpuMod.getOps <mlir::func::FuncOp>())
727
- convertSignature (fn);
734
+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
735
+ for (auto fn : gpuMod.getOps <mlir::gpu::GPUFuncOp>())
736
+ convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
737
+ }
728
738
729
739
return mlir::success ();
730
740
}
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770
780
771
781
// / Determine if the signature has host associations. The host association
772
782
// / argument may need special target specific rewriting.
773
- static bool hasHostAssociations (mlir::func::FuncOp func) {
783
+ template <typename OpTy>
784
+ static bool hasHostAssociations (OpTy func) {
774
785
std::size_t end = func.getFunctionType ().getInputs ().size ();
775
786
for (std::size_t i = 0 ; i < end; ++i)
776
- if (func.getArgAttrOfType <mlir::UnitAttr>(i, fir::getHostAssocAttrName ()))
787
+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
788
+ i, fir::getHostAssocAttrName ()))
777
789
return true ;
778
790
return false ;
779
791
}
780
792
781
793
// / Rewrite the signatures and body of the `FuncOp`s in the module for
782
794
// / the immediately subsequent target code gen.
783
- void convertSignature (mlir::func::FuncOp func) {
795
+ template <typename ReturnOpTy, typename FuncOpTy>
796
+ void convertSignature (FuncOpTy func) {
784
797
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
785
798
if (hasPortableSignature (funcTy, func) && !hasHostAssociations (func))
786
799
return ;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
805
818
// Convert return value(s)
806
819
for (auto ty : funcTy.getResults ())
807
820
llvm::TypeSwitch<mlir::Type>(ty)
808
- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
821
+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
809
822
if (noComplexConversion)
810
823
newResTys.push_back (cmplx);
811
824
else
812
825
doComplexReturn (func, cmplx, newResTys, newInTyAndAttrs, fixups);
813
826
})
814
- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
827
+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
815
828
auto m = specifics->integerArgumentType (func.getLoc (), intTy);
816
829
assert (m.size () == 1 );
817
830
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -825,7 +838,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
825
838
rewriter->getUnitAttr ()));
826
839
newResTys.push_back (retTy);
827
840
})
828
- .Case <fir::RecordType>([&](fir::RecordType recTy) {
841
+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
829
842
doStructReturn (func, recTy, newResTys, newInTyAndAttrs, fixups);
830
843
})
831
844
.Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
840
853
auto ty = e.value ();
841
854
unsigned index = e.index ();
842
855
llvm::TypeSwitch<mlir::Type>(ty)
843
- .Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
856
+ .template Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
844
857
if (noCharacterConversion) {
845
858
newInTyAndAttrs.push_back (
846
859
fir::CodeGenSpecifics::getTypeAndAttr (boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
863
876
}
864
877
}
865
878
})
866
- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
879
+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
867
880
doComplexArg (func, cmplx, newInTyAndAttrs, fixups);
868
881
})
869
- .Case <mlir::TupleType>([&](mlir::TupleType tuple) {
882
+ .template Case <mlir::TupleType>([&](mlir::TupleType tuple) {
870
883
if (fir::isCharacterProcedureTuple (tuple)) {
871
884
fixups.emplace_back (FixupTy::Codes::TrailingCharProc,
872
885
newInTyAndAttrs.size (), trailingTys.size ());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
878
891
fir::CodeGenSpecifics::getTypeAndAttr (ty));
879
892
}
880
893
})
881
- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
894
+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
882
895
auto m = specifics->integerArgumentType (func.getLoc (), intTy);
883
896
assert (m.size () == 1 );
884
897
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
887
900
if (!extensionAttrName.empty () &&
888
901
isFuncWithCCallingConvention (func))
889
902
fixups.emplace_back (FixupTy::Codes::ArgumentType, argNo,
890
- [=](mlir::func::FuncOp func) {
903
+ [=](FuncOpTy func) {
891
904
func.setArgAttr (
892
905
argNo, extensionAttrName,
893
906
mlir::UnitAttr::get (func.getContext ()));
@@ -903,8 +916,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
903
916
fir::CodeGenSpecifics::getTypeAndAttr (ty));
904
917
});
905
918
906
- if (func.getArgAttrOfType <mlir::UnitAttr>(index,
907
- fir::getHostAssocAttrName ())) {
919
+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
920
+ index, fir::getHostAssocAttrName ())) {
908
921
extraAttrs.push_back (
909
922
{newInTyAndAttrs.size () - 1 ,
910
923
rewriter->getNamedAttr (" llvm.nest" , rewriter->getUnitAttr ())});
@@ -979,27 +992,27 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
979
992
auto newArg =
980
993
func.front ().insertArgument (fixup.index , fixupType, loc);
981
994
offset++;
982
- func.walk ([&](mlir::func::ReturnOp ret) {
995
+ func.walk ([&](ReturnOpTy ret) {
983
996
rewriter->setInsertionPoint (ret);
984
997
auto oldOper = ret.getOperand (0 );
985
998
auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
986
999
auto cast =
987
1000
rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
988
1001
rewriter->create <fir::StoreOp>(loc, oldOper, cast);
989
- rewriter->create <mlir::func::ReturnOp >(loc);
1002
+ rewriter->create <ReturnOpTy >(loc);
990
1003
ret.erase ();
991
1004
});
992
1005
} break ;
993
1006
case FixupTy::Codes::ReturnType: {
994
1007
// The function is still returning a value, but its type has likely
995
1008
// changed to suit the target ABI convention.
996
- func.walk ([&](mlir::func::ReturnOp ret) {
1009
+ func.walk ([&](ReturnOpTy ret) {
997
1010
rewriter->setInsertionPoint (ret);
998
1011
auto oldOper = ret.getOperand (0 );
999
1012
mlir::Value bitcast =
1000
1013
convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1001
1014
/* inputMayBeBigger=*/ false );
1002
- rewriter->create <mlir::func::ReturnOp >(loc, bitcast);
1015
+ rewriter->create <ReturnOpTy >(loc, bitcast);
1003
1016
ret.erase ();
1004
1017
});
1005
1018
} break ;
@@ -1101,13 +1114,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1101
1114
}
1102
1115
}
1103
1116
1104
- for (auto &fixup : fixups)
1105
- if (fixup.finalizer )
1106
- (*fixup.finalizer )(func);
1117
+ for (auto &fixup : fixups) {
1118
+ if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
1119
+ if (fixup.finalizer )
1120
+ (*fixup.finalizer )(func);
1121
+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
1122
+ if (fixup.gpuFinalizer )
1123
+ (*fixup.gpuFinalizer )(func);
1124
+ }
1107
1125
}
1108
1126
1109
- template <typename Ty, typename FIXUPS>
1110
- void doReturn (mlir::func::FuncOp func, Ty &newResTys,
1127
+ template <typename OpTy, typename Ty, typename FIXUPS>
1128
+ void doReturn (OpTy func, Ty &newResTys,
1111
1129
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1112
1130
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
1113
1131
assert (m.size () == 1 &&
@@ -1119,7 +1137,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1119
1137
unsigned argNo = newInTyAndAttrs.size ();
1120
1138
if (auto align = attr.getAlignment ())
1121
1139
fixups.emplace_back (
1122
- FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1140
+ FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
1123
1141
auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
1124
1142
func.getFunctionType ().getInput (argNo));
1125
1143
func.setArgAttr (argNo, " llvm.sret" ,
@@ -1130,7 +1148,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1130
1148
});
1131
1149
else
1132
1150
fixups.emplace_back (FixupTy::Codes::ReturnAsStore, argNo,
1133
- [=](mlir::func::FuncOp func) {
1151
+ [=](OpTy func) {
1134
1152
auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
1135
1153
func.getFunctionType ().getInput (argNo));
1136
1154
func.setArgAttr (argNo, " llvm.sret" ,
@@ -1141,8 +1159,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1141
1159
}
1142
1160
if (auto align = attr.getAlignment ())
1143
1161
fixups.emplace_back (
1144
- FixupTy::Codes::ReturnType, newResTys.size (),
1145
- [=](mlir::func::FuncOp func) {
1162
+ FixupTy::Codes::ReturnType, newResTys.size (), [=](OpTy func) {
1146
1163
func.setArgAttr (
1147
1164
newResTys.size (), " llvm.align" ,
1148
1165
rewriter->getIntegerAttr (rewriter->getIntegerType (32 ), align));
@@ -1155,9 +1172,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1155
1172
// / Convert a complex return value. This can involve converting the return
1156
1173
// / value to a "hidden" first argument or packing the complex into a wide
1157
1174
// / GPR.
1158
- template <typename Ty, typename FIXUPS>
1159
- void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1160
- Ty &newResTys,
1175
+ template <typename OpTy, typename Ty, typename FIXUPS>
1176
+ void doComplexReturn (OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
1161
1177
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1162
1178
FIXUPS &fixups) {
1163
1179
if (noComplexConversion) {
@@ -1169,9 +1185,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1169
1185
doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1170
1186
}
1171
1187
1172
- template <typename Ty, typename FIXUPS>
1173
- void doStructReturn (mlir::func::FuncOp func, fir::RecordType recTy,
1174
- Ty &newResTys,
1188
+ template <typename OpTy, typename Ty, typename FIXUPS>
1189
+ void doStructReturn (OpTy func, fir::RecordType recTy, Ty &newResTys,
1175
1190
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1176
1191
FIXUPS &fixups) {
1177
1192
if (noStructConversion) {
@@ -1182,12 +1197,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1182
1197
doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1183
1198
}
1184
1199
1185
- template <typename FIXUPS>
1186
- void
1187
- createFuncOpArgFixups (mlir::func::FuncOp func,
1188
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1189
- fir::CodeGenSpecifics::Marshalling &argsInTys,
1190
- FIXUPS &fixups) {
1200
+ template <typename OpTy, typename FIXUPS>
1201
+ void createFuncOpArgFixups (
1202
+ OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1203
+ fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
1191
1204
const auto fixupCode = argsInTys.size () > 1 ? FixupTy::Codes::Split
1192
1205
: FixupTy::Codes::ArgumentType;
1193
1206
for (auto e : llvm::enumerate (argsInTys)) {
@@ -1198,7 +1211,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1198
1211
if (attr.isByVal ()) {
1199
1212
if (auto align = attr.getAlignment ())
1200
1213
fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad, argNo,
1201
- [=](mlir::func::FuncOp func) {
1214
+ [=](OpTy func) {
1202
1215
auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
1203
1216
func.getFunctionType ().getInput (argNo));
1204
1217
func.setArgAttr (argNo, " llvm.byval" ,
@@ -1210,8 +1223,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1210
1223
});
1211
1224
else
1212
1225
fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad,
1213
- newInTyAndAttrs.size (),
1214
- [=](mlir::func::FuncOp func) {
1226
+ newInTyAndAttrs.size (), [=](OpTy func) {
1215
1227
auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
1216
1228
func.getFunctionType ().getInput (argNo));
1217
1229
func.setArgAttr (argNo, " llvm.byval" ,
@@ -1220,7 +1232,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1220
1232
} else {
1221
1233
if (auto align = attr.getAlignment ())
1222
1234
fixups.emplace_back (
1223
- fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1235
+ fixupCode, argNo, index, [=](OpTy func) {
1224
1236
func.setArgAttr (argNo, " llvm.align" ,
1225
1237
rewriter->getIntegerAttr (
1226
1238
rewriter->getIntegerType (32 ), align));
@@ -1235,8 +1247,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1235
1247
// / Convert a complex argument value. This can involve storing the value to
1236
1248
// / a temporary memory location or factoring the value into two distinct
1237
1249
// / arguments.
1238
- template <typename FIXUPS>
1239
- void doComplexArg (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1250
+ template <typename OpTy, typename FIXUPS>
1251
+ void doComplexArg (OpTy func, mlir::ComplexType cmplx,
1240
1252
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1241
1253
FIXUPS &fixups) {
1242
1254
if (noComplexConversion) {
@@ -1248,8 +1260,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1248
1260
createFuncOpArgFixups (func, newInTyAndAttrs, cplxArgs, fixups);
1249
1261
}
1250
1262
1251
- template <typename FIXUPS>
1252
- void doStructArg (mlir::func::FuncOp func, fir::RecordType recTy,
1263
+ template <typename OpTy, typename FIXUPS>
1264
+ void doStructArg (OpTy func, fir::RecordType recTy,
1253
1265
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1254
1266
FIXUPS &fixups) {
1255
1267
if (noStructConversion) {
0 commit comments