Skip to content

Commit dc5236e

Browse files
authored
[flang][cuda] Update target rewrite to work on gpu.func (#119283)
Update the target-rewrite pass so that, in addition to func.func, it also performs the signature rewrite on gpu.func operations nested inside a gpu.module.
1 parent 1d0ca62 commit dc5236e

File tree

2 files changed

+74
-50
lines changed

2 files changed

+74
-50
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 61 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -62,14 +62,21 @@ struct FixupTy {
6262
FixupTy(Codes code, std::size_t index,
6363
std::function<void(mlir::func::FuncOp)> &&finalizer)
6464
: code{code}, index{index}, finalizer{finalizer} {}
65+
FixupTy(Codes code, std::size_t index,
66+
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
67+
: code{code}, index{index}, gpuFinalizer{finalizer} {}
6568
FixupTy(Codes code, std::size_t index, std::size_t second,
6669
std::function<void(mlir::func::FuncOp)> &&finalizer)
6770
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
71+
FixupTy(Codes code, std::size_t index, std::size_t second,
72+
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
73+
: code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
6874

6975
Codes code;
7076
std::size_t index;
7177
std::size_t second{};
7278
std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
79+
std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
7380
}; // namespace
7481

7582
/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -719,12 +726,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
719726
if (targetFeaturesAttr)
720727
fn->setAttr("target_features", targetFeaturesAttr);
721728

722-
convertSignature(fn);
729+
convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
723730
}
724731

725-
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
732+
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
726733
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
727-
convertSignature(fn);
734+
convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
735+
for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
736+
convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
737+
}
728738

729739
return mlir::success();
730740
}
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770780

771781
/// Determine if the signature has host associations. The host association
772782
/// argument may need special target specific rewriting.
773-
static bool hasHostAssociations(mlir::func::FuncOp func) {
783+
template <typename OpTy>
784+
static bool hasHostAssociations(OpTy func) {
774785
std::size_t end = func.getFunctionType().getInputs().size();
775786
for (std::size_t i = 0; i < end; ++i)
776-
if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
787+
if (func.template getArgAttrOfType<mlir::UnitAttr>(
788+
i, fir::getHostAssocAttrName()))
777789
return true;
778790
return false;
779791
}
780792

781793
/// Rewrite the signatures and body of the `FuncOp`s in the module for
782794
/// the immediately subsequent target code gen.
783-
void convertSignature(mlir::func::FuncOp func) {
795+
template <typename ReturnOpTy, typename FuncOpTy>
796+
void convertSignature(FuncOpTy func) {
784797
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
785798
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
786799
return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
805818
// Convert return value(s)
806819
for (auto ty : funcTy.getResults())
807820
llvm::TypeSwitch<mlir::Type>(ty)
808-
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
821+
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
809822
if (noComplexConversion)
810823
newResTys.push_back(cmplx);
811824
else
812825
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
813826
})
814-
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
827+
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
815828
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
816829
assert(m.size() == 1);
817830
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,7 +838,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
825838
rewriter->getUnitAttr()));
826839
newResTys.push_back(retTy);
827840
})
828-
.Case<fir::RecordType>([&](fir::RecordType recTy) {
841+
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
829842
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
830843
})
831844
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
840853
auto ty = e.value();
841854
unsigned index = e.index();
842855
llvm::TypeSwitch<mlir::Type>(ty)
843-
.Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
856+
.template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
844857
if (noCharacterConversion) {
845858
newInTyAndAttrs.push_back(
846859
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
863876
}
864877
}
865878
})
866-
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
879+
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
867880
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
868881
})
869-
.Case<mlir::TupleType>([&](mlir::TupleType tuple) {
882+
.template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
870883
if (fir::isCharacterProcedureTuple(tuple)) {
871884
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
872885
newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
878891
fir::CodeGenSpecifics::getTypeAndAttr(ty));
879892
}
880893
})
881-
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
894+
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
882895
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
883896
assert(m.size() == 1);
884897
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
887900
if (!extensionAttrName.empty() &&
888901
isFuncWithCCallingConvention(func))
889902
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
890-
[=](mlir::func::FuncOp func) {
903+
[=](FuncOpTy func) {
891904
func.setArgAttr(
892905
argNo, extensionAttrName,
893906
mlir::UnitAttr::get(func.getContext()));
@@ -903,8 +916,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
903916
fir::CodeGenSpecifics::getTypeAndAttr(ty));
904917
});
905918

906-
if (func.getArgAttrOfType<mlir::UnitAttr>(index,
907-
fir::getHostAssocAttrName())) {
919+
if (func.template getArgAttrOfType<mlir::UnitAttr>(
920+
index, fir::getHostAssocAttrName())) {
908921
extraAttrs.push_back(
909922
{newInTyAndAttrs.size() - 1,
910923
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,27 +992,27 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
979992
auto newArg =
980993
func.front().insertArgument(fixup.index, fixupType, loc);
981994
offset++;
982-
func.walk([&](mlir::func::ReturnOp ret) {
995+
func.walk([&](ReturnOpTy ret) {
983996
rewriter->setInsertionPoint(ret);
984997
auto oldOper = ret.getOperand(0);
985998
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
986999
auto cast =
9871000
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
9881001
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
989-
rewriter->create<mlir::func::ReturnOp>(loc);
1002+
rewriter->create<ReturnOpTy>(loc);
9901003
ret.erase();
9911004
});
9921005
} break;
9931006
case FixupTy::Codes::ReturnType: {
9941007
// The function is still returning a value, but its type has likely
9951008
// changed to suit the target ABI convention.
996-
func.walk([&](mlir::func::ReturnOp ret) {
1009+
func.walk([&](ReturnOpTy ret) {
9971010
rewriter->setInsertionPoint(ret);
9981011
auto oldOper = ret.getOperand(0);
9991012
mlir::Value bitcast =
10001013
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
10011014
/*inputMayBeBigger=*/false);
1002-
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
1015+
rewriter->create<ReturnOpTy>(loc, bitcast);
10031016
ret.erase();
10041017
});
10051018
} break;
@@ -1101,13 +1114,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11011114
}
11021115
}
11031116

1104-
for (auto &fixup : fixups)
1105-
if (fixup.finalizer)
1106-
(*fixup.finalizer)(func);
1117+
for (auto &fixup : fixups) {
1118+
if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
1119+
if (fixup.finalizer)
1120+
(*fixup.finalizer)(func);
1121+
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
1122+
if (fixup.gpuFinalizer)
1123+
(*fixup.gpuFinalizer)(func);
1124+
}
11071125
}
11081126

1109-
template <typename Ty, typename FIXUPS>
1110-
void doReturn(mlir::func::FuncOp func, Ty &newResTys,
1127+
template <typename OpTy, typename Ty, typename FIXUPS>
1128+
void doReturn(OpTy func, Ty &newResTys,
11111129
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11121130
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
11131131
assert(m.size() == 1 &&
@@ -1119,7 +1137,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11191137
unsigned argNo = newInTyAndAttrs.size();
11201138
if (auto align = attr.getAlignment())
11211139
fixups.emplace_back(
1122-
FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1140+
FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
11231141
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
11241142
func.getFunctionType().getInput(argNo));
11251143
func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1148,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11301148
});
11311149
else
11321150
fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
1133-
[=](mlir::func::FuncOp func) {
1151+
[=](OpTy func) {
11341152
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
11351153
func.getFunctionType().getInput(argNo));
11361154
func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1159,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11411159
}
11421160
if (auto align = attr.getAlignment())
11431161
fixups.emplace_back(
1144-
FixupTy::Codes::ReturnType, newResTys.size(),
1145-
[=](mlir::func::FuncOp func) {
1162+
FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
11461163
func.setArgAttr(
11471164
newResTys.size(), "llvm.align",
11481165
rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1172,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11551172
/// Convert a complex return value. This can involve converting the return
11561173
/// value to a "hidden" first argument or packing the complex into a wide
11571174
/// GPR.
1158-
template <typename Ty, typename FIXUPS>
1159-
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1160-
Ty &newResTys,
1175+
template <typename OpTy, typename Ty, typename FIXUPS>
1176+
void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
11611177
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11621178
FIXUPS &fixups) {
11631179
if (noComplexConversion) {
@@ -1169,9 +1185,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11691185
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
11701186
}
11711187

1172-
template <typename Ty, typename FIXUPS>
1173-
void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
1174-
Ty &newResTys,
1188+
template <typename OpTy, typename Ty, typename FIXUPS>
1189+
void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
11751190
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11761191
FIXUPS &fixups) {
11771192
if (noStructConversion) {
@@ -1182,12 +1197,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11821197
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
11831198
}
11841199

1185-
template <typename FIXUPS>
1186-
void
1187-
createFuncOpArgFixups(mlir::func::FuncOp func,
1188-
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1189-
fir::CodeGenSpecifics::Marshalling &argsInTys,
1190-
FIXUPS &fixups) {
1200+
template <typename OpTy, typename FIXUPS>
1201+
void createFuncOpArgFixups(
1202+
OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1203+
fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
11911204
const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
11921205
: FixupTy::Codes::ArgumentType;
11931206
for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1211,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11981211
if (attr.isByVal()) {
11991212
if (auto align = attr.getAlignment())
12001213
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
1201-
[=](mlir::func::FuncOp func) {
1214+
[=](OpTy func) {
12021215
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
12031216
func.getFunctionType().getInput(argNo));
12041217
func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1223,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12101223
});
12111224
else
12121225
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
1213-
newInTyAndAttrs.size(),
1214-
[=](mlir::func::FuncOp func) {
1226+
newInTyAndAttrs.size(), [=](OpTy func) {
12151227
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
12161228
func.getFunctionType().getInput(argNo));
12171229
func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1232,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12201232
} else {
12211233
if (auto align = attr.getAlignment())
12221234
fixups.emplace_back(
1223-
fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1235+
fixupCode, argNo, index, [=](OpTy func) {
12241236
func.setArgAttr(argNo, "llvm.align",
12251237
rewriter->getIntegerAttr(
12261238
rewriter->getIntegerType(32), align));
@@ -1235,8 +1247,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12351247
/// Convert a complex argument value. This can involve storing the value to
12361248
/// a temporary memory location or factoring the value into two distinct
12371249
/// arguments.
1238-
template <typename FIXUPS>
1239-
void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1250+
template <typename OpTy, typename FIXUPS>
1251+
void doComplexArg(OpTy func, mlir::ComplexType cmplx,
12401252
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12411253
FIXUPS &fixups) {
12421254
if (noComplexConversion) {
@@ -1248,8 +1260,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12481260
createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
12491261
}
12501262

1251-
template <typename FIXUPS>
1252-
void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
1263+
template <typename OpTy, typename FIXUPS>
1264+
void doStructArg(OpTy func, fir::RecordType recTy,
12531265
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12541266
FIXUPS &fixups) {
12551267
if (noStructConversion) {

flang/test/Fir/CUDA/cuda-target-rewrite.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// REQUIRES: x86-registered-target
2-
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
2+
// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
33

44
gpu.module @testmod {
55
gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
1515
// CHECK-LABEL: gpu.func @_QPvcpowdk
1616
// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
1717
// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
18+
19+
// -----
20+
21+
gpu.module @testmod {
22+
gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
23+
gpu.return %arg0 : complex<f64>
24+
}
25+
}
26+
27+
// CHECK-LABEL: gpu.func @_QPtest
28+
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
29+
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>

0 commit comments

Comments
 (0)