Skip to content

[flang][cuda] Update target rewrite to work on gpu.func #119283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 61 additions & 49 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,21 @@ struct FixupTy {
FixupTy(Codes code, std::size_t index,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, finalizer{finalizer} {}
FixupTy(Codes code, std::size_t index,
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
: code{code}, index{index}, gpuFinalizer{finalizer} {}
FixupTy(Codes code, std::size_t index, std::size_t second,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
FixupTy(Codes code, std::size_t index, std::size_t second,
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
: code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}

Codes code;
std::size_t index;
std::size_t second{};
std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
}; // namespace

/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
Expand Down Expand Up @@ -719,12 +726,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (targetFeaturesAttr)
fn->setAttr("target_features", targetFeaturesAttr);

convertSignature(fn);
convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
}

for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
convertSignature(fn);
convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
}

return mlir::success();
}
Expand Down Expand Up @@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {

/// Determine if the signature has host associations. The host association
/// argument may need special target specific rewriting.
static bool hasHostAssociations(mlir::func::FuncOp func) {
template <typename OpTy>
static bool hasHostAssociations(OpTy func) {
std::size_t end = func.getFunctionType().getInputs().size();
for (std::size_t i = 0; i < end; ++i)
if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
if (func.template getArgAttrOfType<mlir::UnitAttr>(
i, fir::getHostAssocAttrName()))
return true;
return false;
}

/// Rewrite the signatures and body of the `FuncOp`s in the module for
/// the immediately subsequent target code gen.
void convertSignature(mlir::func::FuncOp func) {
template <typename ReturnOpTy, typename FuncOpTy>
void convertSignature(FuncOpTy func) {
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
return;
Expand All @@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
Expand All @@ -825,7 +838,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
.Case<fir::RecordType>([&](fir::RecordType recTy) {
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
Expand All @@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto ty = e.value();
unsigned index = e.index();
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
.template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
Expand All @@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
})
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
.Case<mlir::TupleType>([&](mlir::TupleType tuple) {
.template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
newInTyAndAttrs.size(), trailingTys.size());
Expand All @@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
Expand All @@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (!extensionAttrName.empty() &&
isFuncWithCCallingConvention(func))
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
[=](mlir::func::FuncOp func) {
[=](FuncOpTy func) {
func.setArgAttr(
argNo, extensionAttrName,
mlir::UnitAttr::get(func.getContext()));
Expand All @@ -903,8 +916,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
fir::CodeGenSpecifics::getTypeAndAttr(ty));
});

if (func.getArgAttrOfType<mlir::UnitAttr>(index,
fir::getHostAssocAttrName())) {
if (func.template getArgAttrOfType<mlir::UnitAttr>(
index, fir::getHostAssocAttrName())) {
extraAttrs.push_back(
{newInTyAndAttrs.size() - 1,
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
Expand Down Expand Up @@ -979,27 +992,27 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto newArg =
func.front().insertArgument(fixup.index, fixupType, loc);
offset++;
func.walk([&](mlir::func::ReturnOp ret) {
func.walk([&](ReturnOpTy ret) {
rewriter->setInsertionPoint(ret);
auto oldOper = ret.getOperand(0);
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
auto cast =
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
rewriter->create<mlir::func::ReturnOp>(loc);
rewriter->create<ReturnOpTy>(loc);
ret.erase();
});
} break;
case FixupTy::Codes::ReturnType: {
// The function is still returning a value, but its type has likely
// changed to suit the target ABI convention.
func.walk([&](mlir::func::ReturnOp ret) {
func.walk([&](ReturnOpTy ret) {
rewriter->setInsertionPoint(ret);
auto oldOper = ret.getOperand(0);
mlir::Value bitcast =
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
/*inputMayBeBigger=*/false);
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
rewriter->create<ReturnOpTy>(loc, bitcast);
ret.erase();
});
} break;
Expand Down Expand Up @@ -1101,13 +1114,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}

for (auto &fixup : fixups)
if (fixup.finalizer)
(*fixup.finalizer)(func);
for (auto &fixup : fixups) {
if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
if (fixup.finalizer)
(*fixup.finalizer)(func);
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
if (fixup.gpuFinalizer)
(*fixup.gpuFinalizer)(func);
}
}

template <typename Ty, typename FIXUPS>
void doReturn(mlir::func::FuncOp func, Ty &newResTys,
template <typename OpTy, typename Ty, typename FIXUPS>
void doReturn(OpTy func, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
assert(m.size() == 1 &&
Expand All @@ -1119,7 +1137,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
unsigned argNo = newInTyAndAttrs.size();
if (auto align = attr.getAlignment())
fixups.emplace_back(
FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
Expand All @@ -1130,7 +1148,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
});
else
fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
[=](mlir::func::FuncOp func) {
[=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
Expand All @@ -1141,8 +1159,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
if (auto align = attr.getAlignment())
fixups.emplace_back(
FixupTy::Codes::ReturnType, newResTys.size(),
[=](mlir::func::FuncOp func) {
FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
func.setArgAttr(
newResTys.size(), "llvm.align",
rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
Expand All @@ -1155,9 +1172,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
template <typename Ty, typename FIXUPS>
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
Ty &newResTys,
template <typename OpTy, typename Ty, typename FIXUPS>
void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
Expand All @@ -1169,9 +1185,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}

template <typename Ty, typename FIXUPS>
void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
Ty &newResTys,
template <typename OpTy, typename Ty, typename FIXUPS>
void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
Expand All @@ -1182,12 +1197,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}

template <typename FIXUPS>
void
createFuncOpArgFixups(mlir::func::FuncOp func,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
fir::CodeGenSpecifics::Marshalling &argsInTys,
FIXUPS &fixups) {
template <typename OpTy, typename FIXUPS>
void createFuncOpArgFixups(
OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
: FixupTy::Codes::ArgumentType;
for (auto e : llvm::enumerate(argsInTys)) {
Expand All @@ -1198,7 +1211,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (attr.isByVal()) {
if (auto align = attr.getAlignment())
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
[=](mlir::func::FuncOp func) {
[=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
Expand All @@ -1210,8 +1223,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
});
else
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
newInTyAndAttrs.size(),
[=](mlir::func::FuncOp func) {
newInTyAndAttrs.size(), [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
Expand All @@ -1220,7 +1232,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
} else {
if (auto align = attr.getAlignment())
fixups.emplace_back(
fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
fixupCode, argNo, index, [=](OpTy func) {
func.setArgAttr(argNo, "llvm.align",
rewriter->getIntegerAttr(
rewriter->getIntegerType(32), align));
Expand All @@ -1235,8 +1247,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex argument value. This can involve storing the value to
/// a temporary memory location or factoring the value into two distinct
/// arguments.
template <typename FIXUPS>
void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
template <typename OpTy, typename FIXUPS>
void doComplexArg(OpTy func, mlir::ComplexType cmplx,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
Expand All @@ -1248,8 +1260,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
}

template <typename FIXUPS>
void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
template <typename OpTy, typename FIXUPS>
void doStructArg(OpTy func, fir::RecordType recTy,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
Expand Down
14 changes: 13 additions & 1 deletion flang/test/Fir/CUDA/cuda-target-rewrite.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// REQUIRES: x86-registered-target
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s

gpu.module @testmod {
gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
Expand All @@ -15,3 +15,15 @@ gpu.module @testmod {
// CHECK-LABEL: gpu.func @_QPvcpowdk
// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}

// -----

gpu.module @testmod {
gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
gpu.return %arg0 : complex<f64>
}
}

// CHECK-LABEL: gpu.func @_QPtest
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>
Loading