Skip to content

Commit 480e7f0

Browse files
authored
[flang] correctly deal with bind(c) derived type result ABI (#111678)
Derived type results of BIND(C) functions should be returned according to the C ABI for returning the related C struct type. This did not happen so far because the abstract-result pass was forcing the Fortran ABI for all derived type results. Use the bind_c attribute that was added on call/func/dispatch in FIR to prevent such a rewrite in the abstract-result pass, and update the target-rewrite pass to deal with the struct return ABI. So far, the target-specific part of the target-rewrite is only implemented for X86-64 according to the "System V Application Binary Interface AMD64 v1"; the other targets will hit a TODO, just like for BIND(C), VALUE derived type arguments. This intends to deal with #102113.
1 parent 9014920 commit 480e7f0

File tree

7 files changed

+419
-40
lines changed

7 files changed

+419
-40
lines changed

flang/include/flang/Optimizer/CodeGen/Target.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class CodeGenSpecifics {
126126
structArgumentType(mlir::Location loc, fir::RecordType recTy,
127127
const Marshalling &previousArguments) const = 0;
128128

129+
/// Type representation of a `fir.type<T>` type argument when returned by
130+
/// value. Such value may need to be converted to a hidden reference argument.
131+
virtual Marshalling structReturnType(mlir::Location loc,
132+
fir::RecordType eleTy) const = 0;
133+
129134
/// Type representation of a `boxchar<n>` type argument when passed by value.
130135
/// An argument value may need to be passed as a (safe) reference argument.
131136
///

flang/include/flang/Optimizer/Dialect/FIROpsSupport.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,27 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
177177
}
178178

179179
bool isDummyArgument(mlir::Value v);
180+
181+
template <fir::FortranProcedureFlagsEnum Flag>
182+
inline bool hasProcedureAttr(fir::FortranProcedureFlagsEnumAttr flags) {
183+
return flags && bitEnumContainsAny(flags.getValue(), Flag);
184+
}
185+
186+
template <fir::FortranProcedureFlagsEnum Flag>
187+
inline bool hasProcedureAttr(mlir::Operation *op) {
188+
if (auto firCallOp = mlir::dyn_cast<fir::CallOp>(op))
189+
return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
190+
if (auto firCallOp = mlir::dyn_cast<fir::DispatchOp>(op))
191+
return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
192+
return hasProcedureAttr<Flag>(
193+
op->getAttrOfType<fir::FortranProcedureFlagsEnumAttr>(
194+
getFortranProcedureFlagsAttrName()));
195+
}
196+
197+
/// Tell whether \p op carries the bind_c Fortran procedure flag.
inline bool hasBindcAttr(mlir::Operation *op) {
  constexpr auto bindC = fir::FortranProcedureFlagsEnum::bind_c;
  return hasProcedureAttr<bindC>(op);
}
200+
180201
} // namespace fir
181202

182203
#endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ struct GenericTarget : public CodeGenSpecifics {
100100
TODO(loc, "passing VALUE BIND(C) derived type for this target");
101101
}
102102

103+
/// Fallback used by targets that do not provide their own implementation:
/// returning a BIND(C) derived type by value is reported as a TODO.
CodeGenSpecifics::Marshalling structReturnType(mlir::Location loc,
                                               fir::RecordType ty) const override {
  TODO(loc, "returning BIND(C) derived type for this target");
}
107+
103108
CodeGenSpecifics::Marshalling
104109
integerArgumentType(mlir::Location loc,
105110
mlir::IntegerType argTy) const override {
@@ -533,14 +538,17 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
533538
/// When \p recTy is a one field record type that can be passed
534539
/// like the field on its own, returns the field type. Returns
535540
/// a null type otherwise.
536-
mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy) const {
541+
mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy,
542+
bool allowComplex = false) const {
537543
auto typeList = recTy.getTypeList();
538544
if (typeList.size() != 1)
539545
return {};
540546
mlir::Type fieldType = typeList[0].second;
541547
if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::LogicalType>(
542548
fieldType))
543549
return fieldType;
550+
if (allowComplex && mlir::isa<mlir::ComplexType>(fieldType))
551+
return fieldType;
544552
if (mlir::isa<fir::CharacterType>(fieldType)) {
545553
// Only CHARACTER(1) are expected in BIND(C) contexts, which is the only
546554
// contexts where derived type may be passed in registers.
@@ -593,7 +601,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
593601
postMerge(byteOffset, Lo, Hi);
594602
if (Lo == ArgClass::Memory || Lo == ArgClass::X87 ||
595603
Lo == ArgClass::ComplexX87)
596-
return passOnTheStack(loc, recTy);
604+
return passOnTheStack(loc, recTy, /*isResult=*/false);
597605
int neededIntRegisters = 0;
598606
int neededSSERegisters = 0;
599607
if (Lo == ArgClass::SSE)
@@ -609,7 +617,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
609617
// all in registers or all on the stack).
610618
if (!hasEnoughRegisters(loc, neededIntRegisters, neededSSERegisters,
611619
previousArguments))
612-
return passOnTheStack(loc, recTy);
620+
return passOnTheStack(loc, recTy, /*isResult=*/false);
613621

614622
if (auto fieldType = passAsFieldIfOneFieldStruct(recTy)) {
615623
CodeGenSpecifics::Marshalling marshal;
@@ -641,17 +649,65 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
641649
return marshal;
642650
}
643651

652+
/// Marshal a derived type that is returned by value. The record is
/// classified and post-merged first; the outcome is one of:
///  - a reference flagged as struct-return when the class is Memory,
///  - the single field type (or the complex return marshalling) for
///    one-field records,
///  - one LLVM type when the value fits in a single register,
///  - a {low, high} tuple when two registers are needed.
CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType recTy) const override {
  // Build a marshalling made of one type with default attributes.
  auto asSingleResult = [](mlir::Type type) {
    CodeGenSpecifics::Marshalling marshal;
    marshal.emplace_back(type, AT{});
    return marshal;
  };

  ArgClass lowClass = ArgClass::NoClass;
  ArgClass highClass = ArgClass::NoClass;
  std::uint64_t endOffset = 0;
  endOffset = classifyStruct(loc, recTy, endOffset, lowClass, highClass);
  postMerge(endOffset, lowClass, highClass);

  if (lowClass == ArgClass::Memory)
    return passOnTheStack(loc, recTy, /*isResult=*/true);

  // X87/ComplexX87 values are passed in memory when they are arguments, but
  // they are returned via the %st0 %st1 registers. They reach the one-field
  // path below as fp80 or {fp80, fp80}, and LLVM will use the expected
  // registers.
  //
  // {_Complex long double} is not 100% clear from an ABI perspective: the
  // aggregate post merger rules say it should be passed in memory because
  // it is bigger than two eightbytes. As a consequence, returning
  // {_Complex long double} is dealt with differently than returning
  // _Complex long double.
  if (mlir::Type fieldType =
          passAsFieldIfOneFieldStruct(recTy, /*allowComplex=*/true)) {
    if (auto complexType = mlir::dyn_cast<mlir::ComplexType>(fieldType))
      return complexReturnType(loc, complexType.getElementType());
    return asSingleResult(fieldType);
  }

  mlir::MLIRContext *context = recTy.getContext();
  const bool usesSingleRegister =
      highClass == ArgClass::NoClass || highClass == ArgClass::SSEUp;
  if (usesSingleRegister) {
    // Return a single integer or floating point value.
    return asSingleResult(pickLLVMArgType(loc, context, lowClass, endOffset));
  }

  // The value is returned in two different registers: use {lowTy, highTy} as
  // the result type.
  mlir::Type lowType = pickLLVMArgType(loc, context, lowClass, 8u);
  mlir::Type highType =
      pickLLVMArgType(loc, context, highClass, endOffset - 8u);
  return asSingleResult(mlir::TupleType::get(context, {lowType, highType}));
}
699+
644700
/// Marshal an argument that must be passed on the stack.
645-
CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
646-
mlir::Type ty) const {
701+
CodeGenSpecifics::Marshalling
702+
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
647703
CodeGenSpecifics::Marshalling marshal;
648704
auto sizeAndAlign =
649705
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
650706
// The stack is always 8 byte aligned (note 14 in 3.2.3).
651707
unsigned short align =
652708
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
653709
marshal.emplace_back(fir::ReferenceType::get(ty),
654-
AT{align, /*byval=*/true, /*sret=*/false});
710+
AT{align, /*byval=*/!isResult, /*sret=*/isResult});
655711
return marshal;
656712
}
657713
};

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 107 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
142142

143143
mlir::ModuleOp getModule() { return getOperation(); }
144144

145-
template <typename A, typename B, typename C>
145+
template <typename Ty, typename Callback>
146146
std::optional<std::function<mlir::Value(mlir::Operation *)>>
147-
rewriteCallComplexResultType(
148-
mlir::Location loc, A ty, B &newResTys,
149-
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
150-
mlir::Value &savedStackPtr) {
151-
if (noComplexConversion) {
152-
newResTys.push_back(ty);
153-
return std::nullopt;
154-
}
155-
auto m = specifics->complexReturnType(loc, ty.getElementType());
156-
// Currently targets mandate COMPLEX is a single aggregate or packed
157-
// scalar, including the sret case.
158-
assert(m.size() == 1 && "target of complex return not supported");
147+
rewriteCallResultType(mlir::Location loc, mlir::Type originalResTy,
148+
Ty &newResTys,
149+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
150+
Callback &newOpers, mlir::Value &savedStackPtr,
151+
fir::CodeGenSpecifics::Marshalling &m) {
152+
// Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
153+
// packed scalar, including the sret case.
154+
assert(m.size() == 1 && "return type not supported on this target");
159155
auto resTy = std::get<mlir::Type>(m[0]);
160156
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
161157
if (attr.isSRet()) {
@@ -170,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
170166
newInTyAndAttrs.push_back(m[0]);
171167
newOpers.push_back(stack);
172168
return [=](mlir::Operation *) -> mlir::Value {
173-
auto memTy = fir::ReferenceType::get(ty);
169+
auto memTy = fir::ReferenceType::get(originalResTy);
174170
auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
175171
return rewriter->create<fir::LoadOp>(loc, cast);
176172
};
@@ -180,11 +176,41 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180176
// We are going to generate an alloca, so save the stack pointer.
181177
if (!savedStackPtr)
182178
savedStackPtr = genStackSave(loc);
183-
return this->convertValueInMemory(loc, call->getResult(0), ty,
179+
return this->convertValueInMemory(loc, call->getResult(0), originalResTy,
184180
/*inputMayBeBigger=*/true);
185181
};
186182
}
187183

184+
template <typename Ty, typename Callback>
185+
std::optional<std::function<mlir::Value(mlir::Operation *)>>
186+
rewriteCallComplexResultType(
187+
mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
188+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
189+
mlir::Value &savedStackPtr) {
190+
if (noComplexConversion) {
191+
newResTys.push_back(ty);
192+
return std::nullopt;
193+
}
194+
auto m = specifics->complexReturnType(loc, ty.getElementType());
195+
return rewriteCallResultType(loc, ty, newResTys, newInTyAndAttrs, newOpers,
196+
savedStackPtr, m);
197+
}
198+
199+
template <typename Ty, typename Callback>
200+
std::optional<std::function<mlir::Value(mlir::Operation *)>>
201+
rewriteCallStructResultType(
202+
mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
203+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
204+
mlir::Value &savedStackPtr) {
205+
if (noStructConversion) {
206+
newResTys.push_back(recTy);
207+
return std::nullopt;
208+
}
209+
auto m = specifics->structReturnType(loc, recTy);
210+
return rewriteCallResultType(loc, recTy, newResTys, newInTyAndAttrs,
211+
newOpers, savedStackPtr, m);
212+
}
213+
188214
void passArgumentOnStackOrWithNewType(
189215
mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
190216
mlir::Type oldType, mlir::Value oper,
@@ -356,6 +382,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
356382
newInTyAndAttrs, newOpers,
357383
savedStackPtr);
358384
})
385+
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
386+
wrap = rewriteCallStructResultType(loc, recTy, newResTys,
387+
newInTyAndAttrs, newOpers,
388+
savedStackPtr);
389+
})
359390
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
360391
} else if (fnTy.getResults().size() > 1) {
361392
TODO(loc, "multiple results not supported yet");
@@ -562,6 +593,24 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
562593
}
563594
}
564595

596+
template <typename Ty>
597+
void
598+
lowerStructSignatureRes(mlir::Location loc, fir::RecordType recTy,
599+
Ty &newResTys,
600+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
601+
if (noComplexConversion) {
602+
newResTys.push_back(recTy);
603+
return;
604+
} else {
605+
for (auto &tup : specifics->structReturnType(loc, recTy)) {
606+
if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
607+
newInTyAndAttrs.push_back(tup);
608+
else
609+
newResTys.push_back(std::get<mlir::Type>(tup));
610+
}
611+
}
612+
}
613+
565614
void
566615
lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
567616
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
@@ -595,6 +644,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
595644
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
596645
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
597646
})
647+
.Case<fir::RecordType>([&](fir::RecordType ty) {
648+
lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
649+
})
598650
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
599651
}
600652
llvm::SmallVector<mlir::Type> trailingInTys;
@@ -696,7 +748,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
696748
for (auto ty : func.getResults())
697749
if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
698750
(fir::isa_complex(ty) && !noComplexConversion) ||
699-
(mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv)) {
751+
(mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
752+
(mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
700753
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
701754
return false;
702755
}
@@ -770,6 +823,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770823
rewriter->getUnitAttr()));
771824
newResTys.push_back(retTy);
772825
})
826+
.Case<fir::RecordType>([&](fir::RecordType recTy) {
827+
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
828+
})
773829
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
774830

775831
// Saved potential shift in argument. Handling of result can add arguments
@@ -1062,21 +1118,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
10621118
return false;
10631119
}
10641120

1065-
/// Convert a complex return value. This can involve converting the return
1066-
/// value to a "hidden" first argument or packing the complex into a wide
1067-
/// GPR.
10681121
template <typename Ty, typename FIXUPS>
1069-
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1070-
Ty &newResTys,
1071-
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1072-
FIXUPS &fixups) {
1073-
if (noComplexConversion) {
1074-
newResTys.push_back(cmplx);
1075-
return;
1076-
}
1077-
auto m =
1078-
specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
1079-
assert(m.size() == 1);
1122+
void doReturn(mlir::func::FuncOp func, Ty &newResTys,
1123+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1124+
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
1125+
assert(m.size() == 1 &&
1126+
"expect result to be turned into single argument or result so far");
10801127
auto &tup = m[0];
10811128
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
10821129
auto argTy = std::get<mlir::Type>(tup);
@@ -1117,6 +1164,36 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11171164
newResTys.push_back(argTy);
11181165
}
11191166

1167+
/// Convert a complex return value. This can involve converting the return
1168+
/// value to a "hidden" first argument or packing the complex into a wide
1169+
/// GPR.
1170+
template <typename Ty, typename FIXUPS>
1171+
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1172+
Ty &newResTys,
1173+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1174+
FIXUPS &fixups) {
1175+
if (noComplexConversion) {
1176+
newResTys.push_back(cmplx);
1177+
return;
1178+
}
1179+
auto m =
1180+
specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
1181+
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
1182+
}
1183+
1184+
template <typename Ty, typename FIXUPS>
1185+
void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
1186+
Ty &newResTys,
1187+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1188+
FIXUPS &fixups) {
1189+
if (noStructConversion) {
1190+
newResTys.push_back(recTy);
1191+
return;
1192+
}
1193+
auto m = specifics->structReturnType(func.getLoc(), recTy);
1194+
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
1195+
}
1196+
11201197
template <typename FIXUPS>
11211198
void
11221199
createFuncOpArgFixups(mlir::func::FuncOp func,

0 commit comments

Comments
 (0)