Skip to content

[flang] correctly deal with bind(c) derived type result ABI #111678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class CodeGenSpecifics {
structArgumentType(mlir::Location loc, fir::RecordType recTy,
const Marshalling &previousArguments) const = 0;

/// Type representation of a `fir.type<T>` type argument when returned by
/// value. Such value may need to be converted to a hidden reference argument.
virtual Marshalling structReturnType(mlir::Location loc,
fir::RecordType eleTy) const = 0;

/// Type representation of a `boxchar<n>` type argument when passed by value.
/// An argument value may need to be passed as a (safe) reference argument.
///
Expand Down
21 changes: 21 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
}

bool isDummyArgument(mlir::Value v);

template <fir::FortranProcedureFlagsEnum Flag>
inline bool hasProcedureAttr(fir::FortranProcedureFlagsEnumAttr flags) {
return flags && bitEnumContainsAny(flags.getValue(), Flag);
}

template <fir::FortranProcedureFlagsEnum Flag>
inline bool hasProcedureAttr(mlir::Operation *op) {
if (auto firCallOp = mlir::dyn_cast<fir::CallOp>(op))
return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
if (auto firCallOp = mlir::dyn_cast<fir::DispatchOp>(op))
return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
return hasProcedureAttr<Flag>(
op->getAttrOfType<fir::FortranProcedureFlagsEnumAttr>(
getFortranProcedureFlagsAttrName()));
}

inline bool hasBindcAttr(mlir::Operation *op) {
return hasProcedureAttr<fir::FortranProcedureFlagsEnum::bind_c>(op);
}

} // namespace fir

#endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
68 changes: 62 additions & 6 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ struct GenericTarget : public CodeGenSpecifics {
TODO(loc, "passing VALUE BIND(C) derived type for this target");
}

CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
TODO(loc, "returning BIND(C) derived type for this target");
}

CodeGenSpecifics::Marshalling
integerArgumentType(mlir::Location loc,
mlir::IntegerType argTy) const override {
Expand Down Expand Up @@ -533,14 +538,17 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
/// When \p recTy is a one field record type that can be passed
/// like the field on its own, returns the field type. Returns
/// a null type otherwise.
mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy) const {
mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy,
bool allowComplex = false) const {
auto typeList = recTy.getTypeList();
if (typeList.size() != 1)
return {};
mlir::Type fieldType = typeList[0].second;
if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::LogicalType>(
fieldType))
return fieldType;
if (allowComplex && mlir::isa<mlir::ComplexType>(fieldType))
return fieldType;
if (mlir::isa<fir::CharacterType>(fieldType)) {
// Only CHARACTER(1) are expected in BIND(C) contexts, which is the only
// contexts where derived type may be passed in registers.
Expand Down Expand Up @@ -593,7 +601,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
postMerge(byteOffset, Lo, Hi);
if (Lo == ArgClass::Memory || Lo == ArgClass::X87 ||
Lo == ArgClass::ComplexX87)
return passOnTheStack(loc, recTy);
return passOnTheStack(loc, recTy, /*isResult=*/false);
int neededIntRegisters = 0;
int neededSSERegisters = 0;
if (Lo == ArgClass::SSE)
Expand All @@ -609,7 +617,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
// all in registers or all on the stack).
if (!hasEnoughRegisters(loc, neededIntRegisters, neededSSERegisters,
previousArguments))
return passOnTheStack(loc, recTy);
return passOnTheStack(loc, recTy, /*isResult=*/false);

if (auto fieldType = passAsFieldIfOneFieldStruct(recTy)) {
CodeGenSpecifics::Marshalling marshal;
Expand Down Expand Up @@ -641,17 +649,65 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
return marshal;
}

CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType recTy) const override {
std::uint64_t byteOffset = 0;
ArgClass Lo, Hi;
Lo = Hi = ArgClass::NoClass;
byteOffset = classifyStruct(loc, recTy, byteOffset, Lo, Hi);
mlir::MLIRContext *context = recTy.getContext();
postMerge(byteOffset, Lo, Hi);
if (Lo == ArgClass::Memory)
return passOnTheStack(loc, recTy, /*isResult=*/true);

// Note that X87/ComplexX87 are passed in memory, but returned via %st0
// %st1 registers. Here, they are returned as fp80 or {fp80, fp80} by
// passAsFieldIfOneFieldStruct, and LLVM will use the expected registers.

// Note that {_Complex long double} is not 100% clear from an ABI
// perspective because the aggregate post merger rules say it should be
// passed in memory because it is bigger than 2 eight bytes. This has the
// funny effect of
// {_Complex long double} return to be dealt with differently than
// _Complex long double.

if (auto fieldType =
passAsFieldIfOneFieldStruct(recTy, /*allowComplex=*/true)) {
if (auto complexType = mlir::dyn_cast<mlir::ComplexType>(fieldType))
return complexReturnType(loc, complexType.getElementType());
CodeGenSpecifics::Marshalling marshal;
marshal.emplace_back(fieldType, AT{});
return marshal;
}

if (Hi == ArgClass::NoClass || Hi == ArgClass::SSEUp) {
// Return a single integer or floating point argument.
mlir::Type lowType = pickLLVMArgType(loc, context, Lo, byteOffset);
CodeGenSpecifics::Marshalling marshal;
marshal.emplace_back(lowType, AT{});
return marshal;
}
// Will be returned in two different registers. Generate {lowTy, HiTy} for
// the LLVM IR result type.
CodeGenSpecifics::Marshalling marshal;
mlir::Type lowType = pickLLVMArgType(loc, context, Lo, 8u);
mlir::Type hiType = pickLLVMArgType(loc, context, Hi, byteOffset - 8u);
marshal.emplace_back(mlir::TupleType::get(context, {lowType, hiType}),
AT{});
return marshal;
}

/// Marshal an argument that must be passed on the stack.
CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
mlir::Type ty) const {
CodeGenSpecifics::Marshalling
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
CodeGenSpecifics::Marshalling marshal;
auto sizeAndAlign =
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
// The stack is always 8 byte aligned (note 14 in 3.2.3).
unsigned short align =
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
marshal.emplace_back(fir::ReferenceType::get(ty),
AT{align, /*byval=*/true, /*sret=*/false});
AT{align, /*byval=*/!isResult, /*sret=*/isResult});
return marshal;
}
};
Expand Down
137 changes: 107 additions & 30 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {

mlir::ModuleOp getModule() { return getOperation(); }

template <typename A, typename B, typename C>
template <typename Ty, typename Callback>
std::optional<std::function<mlir::Value(mlir::Operation *)>>
rewriteCallComplexResultType(
mlir::Location loc, A ty, B &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
mlir::Value &savedStackPtr) {
if (noComplexConversion) {
newResTys.push_back(ty);
return std::nullopt;
}
auto m = specifics->complexReturnType(loc, ty.getElementType());
// Currently targets mandate COMPLEX is a single aggregate or packed
// scalar, including the sret case.
assert(m.size() == 1 && "target of complex return not supported");
rewriteCallResultType(mlir::Location loc, mlir::Type originalResTy,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
Callback &newOpers, mlir::Value &savedStackPtr,
fir::CodeGenSpecifics::Marshalling &m) {
// Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
// packed scalar, including the sret case.
assert(m.size() == 1 && "return type not supported on this target");
auto resTy = std::get<mlir::Type>(m[0]);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
if (attr.isSRet()) {
Expand All @@ -170,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newInTyAndAttrs.push_back(m[0]);
newOpers.push_back(stack);
return [=](mlir::Operation *) -> mlir::Value {
auto memTy = fir::ReferenceType::get(ty);
auto memTy = fir::ReferenceType::get(originalResTy);
auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
return rewriter->create<fir::LoadOp>(loc, cast);
};
Expand All @@ -180,11 +176,41 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// We are going to generate an alloca, so save the stack pointer.
if (!savedStackPtr)
savedStackPtr = genStackSave(loc);
return this->convertValueInMemory(loc, call->getResult(0), ty,
return this->convertValueInMemory(loc, call->getResult(0), originalResTy,
/*inputMayBeBigger=*/true);
};
}

template <typename Ty, typename Callback>
std::optional<std::function<mlir::Value(mlir::Operation *)>>
rewriteCallComplexResultType(
mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
mlir::Value &savedStackPtr) {
if (noComplexConversion) {
newResTys.push_back(ty);
return std::nullopt;
}
auto m = specifics->complexReturnType(loc, ty.getElementType());
return rewriteCallResultType(loc, ty, newResTys, newInTyAndAttrs, newOpers,
savedStackPtr, m);
}

template <typename Ty, typename Callback>
std::optional<std::function<mlir::Value(mlir::Operation *)>>
rewriteCallStructResultType(
mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
mlir::Value &savedStackPtr) {
if (noStructConversion) {
newResTys.push_back(recTy);
return std::nullopt;
}
auto m = specifics->structReturnType(loc, recTy);
return rewriteCallResultType(loc, recTy, newResTys, newInTyAndAttrs,
newOpers, savedStackPtr, m);
}

void passArgumentOnStackOrWithNewType(
mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
mlir::Type oldType, mlir::Value oper,
Expand Down Expand Up @@ -356,6 +382,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newInTyAndAttrs, newOpers,
savedStackPtr);
})
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
wrap = rewriteCallStructResultType(loc, recTy, newResTys,
newInTyAndAttrs, newOpers,
savedStackPtr);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
} else if (fnTy.getResults().size() > 1) {
TODO(loc, "multiple results not supported yet");
Expand Down Expand Up @@ -562,6 +593,24 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}

template <typename Ty>
void
lowerStructSignatureRes(mlir::Location loc, fir::RecordType recTy,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
if (noComplexConversion) {
newResTys.push_back(recTy);
return;
} else {
for (auto &tup : specifics->structReturnType(loc, recTy)) {
if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
newInTyAndAttrs.push_back(tup);
else
newResTys.push_back(std::get<mlir::Type>(tup));
}
}
}

void
lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
Expand Down Expand Up @@ -595,6 +644,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Case<fir::RecordType>([&](fir::RecordType ty) {
lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
}
llvm::SmallVector<mlir::Type> trailingInTys;
Expand Down Expand Up @@ -696,7 +748,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
for (auto ty : func.getResults())
if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
(fir::isa_complex(ty) && !noComplexConversion) ||
(mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv)) {
(mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
(mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
return false;
}
Expand Down Expand Up @@ -770,6 +823,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
.Case<fir::RecordType>([&](fir::RecordType recTy) {
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });

// Saved potential shift in argument. Handling of result can add arguments
Expand Down Expand Up @@ -1062,21 +1118,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
return false;
}

/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
template <typename Ty, typename FIXUPS>
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
return;
}
auto m =
specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
assert(m.size() == 1);
void doReturn(mlir::func::FuncOp func, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
assert(m.size() == 1 &&
"expect result to be turned into single argument or result so far");
auto &tup = m[0];
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
auto argTy = std::get<mlir::Type>(tup);
Expand Down Expand Up @@ -1117,6 +1164,36 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newResTys.push_back(argTy);
}

/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
template <typename Ty, typename FIXUPS>
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
return;
}
auto m =
specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}

template <typename Ty, typename FIXUPS>
void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
newResTys.push_back(recTy);
return;
}
auto m = specifics->structReturnType(func.getLoc(), recTy);
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}

template <typename FIXUPS>
void
createFuncOpArgFixups(mlir::func::FuncOp func,
Expand Down
Loading
Loading