Skip to content

Commit 09db497

Browse files
committed
[flang] finish BIND(C) VALUE derived type passing ABI on X86-64
Derived type passed with VALUE in BIND(C) context must be passed like C struct and LLVM is not implementing the ABI for this (it is up to the frontends like clang). Previous patch llvm#75802 implemented the simple cases where the derived type have one field, this patch implements the general case. Note that the generated LLVM IR is compliant from a X86-64 C ABI point of view and compatible with clang generated assembly, but that it is not guaranteed to match the LLVM IR signatures generated by clang for the C equivalent functions because several LLVM IR signatures may lead to the same X86-64 signature.
1 parent b8dca4f commit 09db497

File tree

3 files changed

+322
-61
lines changed

3 files changed

+322
-61
lines changed

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,36 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
604604
return {};
605605
}
606606

607+
mlir::Type pickLLVMArgType(mlir::Location loc, mlir::MLIRContext *context,
608+
ArgClass argClass,
609+
std::uint64_t partByteSize) const {
610+
if (argClass == ArgClass::SSE) {
611+
if (partByteSize > 16)
612+
TODO(loc, "passing struct as a real > 128 bits in register");
613+
// Clang uses vector type when several fp fields are marshalled
614+
// into a single SSE register (like <n x smallest fp field> ).
615+
// It should make no difference from an ABI point of view to just
616+
// select an fp type of the right size, and it makes things simpler
617+
// here.
618+
if (partByteSize > 8)
619+
return mlir::FloatType::getF128(context);
620+
if (partByteSize > 4)
621+
return mlir::FloatType::getF64(context);
622+
if (partByteSize > 2)
623+
return mlir::FloatType::getF32(context);
624+
return mlir::FloatType::getF16(context);
625+
}
626+
assert(partByteSize <= 8 &&
627+
"expect integer part of aggregate argument to fit into eight bytes");
628+
if (partByteSize > 4)
629+
return mlir::IntegerType::get(context, 64);
630+
if (partByteSize > 2)
631+
return mlir::IntegerType::get(context, 32);
632+
if (partByteSize > 1)
633+
return mlir::IntegerType::get(context, 16);
634+
return mlir::IntegerType::get(context, 8);
635+
}
636+
607637
/// Marshal a derived type passed by value like a C struct.
608638
CodeGenSpecifics::Marshalling
609639
structArgumentType(mlir::Location loc, fir::RecordType recTy,
@@ -638,9 +668,29 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
638668
marshal.emplace_back(fieldType, AT{});
639669
return marshal;
640670
}
641-
// TODO, marshal the struct with several components, or with a single
642-
// complex, array, or derived type component into registers.
643-
TODO(loc, "passing BIND(C), VALUE derived type in registers on X86-64");
671+
if (Hi == ArgClass::NoClass || Hi == ArgClass::SSEUp) {
672+
// Pass a single integer or floating point argument.
673+
mlir::Type lowType =
674+
pickLLVMArgType(loc, recTy.getContext(), Lo, byteOffset);
675+
CodeGenSpecifics::Marshalling marshal;
676+
marshal.emplace_back(lowType, AT{});
677+
return marshal;
678+
}
679+
// Split into two integer or floating point arguments.
680+
// Note that for the first argument, this will always pick i64 or f64 which
681+
// may be bigger than needed if some struct padding ends the first eight
682+
// byte (e.g. for `{i32, f64}`). It is valid from an X86-64 ABI and
683+
// semantic point of view, but it may not match the LLVM IR interface clang
684+
// would produce for the equivalent C code (the assembly will still be
685+
// compatible). This allows keeping the logic simpler here since it
686+
// avoids computing the "data" size of the Lo part.
687+
mlir::Type lowType = pickLLVMArgType(loc, recTy.getContext(), Lo, 8u);
688+
mlir::Type hiType =
689+
pickLLVMArgType(loc, recTy.getContext(), Hi, byteOffset - 8u);
690+
CodeGenSpecifics::Marshalling marshal;
691+
marshal.emplace_back(lowType, AT{});
692+
marshal.emplace_back(hiType, AT{});
693+
return marshal;
644694
}
645695

646696
/// Marshal an argument that must be passed on the stack.

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 110 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180180
// We are going to generate an alloca, so save the stack pointer.
181181
if (!savedStackPtr)
182182
savedStackPtr = genStackSave(loc);
183-
auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
184-
rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
185-
auto memTy = fir::ReferenceType::get(ty);
186-
auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem);
187-
return rewriter->create<fir::LoadOp>(loc, cast);
183+
return this->convertValueInMemory(loc, call->getResult(0), ty,
184+
/*inputMayBeBigger=*/true);
188185
};
189186
}
190187

@@ -195,7 +192,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
195192
mlir::Value &savedStackPtr) {
196193
auto resTy = std::get<mlir::Type>(newTypeAndAttr);
197194
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
198-
auto oldRefTy = fir::ReferenceType::get(oldType);
199195
// We are going to generate an alloca, so save the stack pointer.
200196
if (!savedStackPtr)
201197
savedStackPtr = genStackSave(loc);
@@ -206,11 +202,83 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
206202
mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
207203
newOpers.push_back(mem);
208204
} else {
209-
auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
205+
mlir::Value bitcast =
206+
convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
207+
newOpers.push_back(bitcast);
208+
}
209+
}
210+
211+
// Do a bitcast (convert a value via its memory representation).
212+
// The input and output types may have different storage sizes,
213+
// "inputMayBeBigger" should be set to indicate which of the input or
214+
// output type may be bigger in order for the load/store to be safe.
215+
// The mismatch comes from the fact that the LLVM register used for passing
216+
// may be bigger than the value being passed (e.g., passing
217+
// a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
218+
mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
219+
mlir::Type newType, bool inputMayBeBigger) {
220+
if (inputMayBeBigger) {
221+
auto newRefTy = fir::ReferenceType::get(newType);
222+
auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
223+
rewriter->create<fir::StoreOp>(loc, value, mem);
224+
auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
225+
return rewriter->create<fir::LoadOp>(loc, cast);
226+
} else {
227+
auto oldRefTy = fir::ReferenceType::get(value.getType());
228+
auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
210229
auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
211-
rewriter->create<fir::StoreOp>(loc, oper, cast);
212-
newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
230+
rewriter->create<fir::StoreOp>(loc, value, cast);
231+
return rewriter->create<fir::LoadOp>(loc, mem);
232+
}
233+
}
234+
235+
void passSplitArgument(mlir::Location loc,
236+
fir::CodeGenSpecifics::Marshalling splitArgs,
237+
mlir::Type oldType, mlir::Value oper,
238+
llvm::SmallVectorImpl<mlir::Value> &newOpers,
239+
mlir::Value &savedStackPtr) {
240+
// COMPLEX or struct argument split into separate arguments
241+
if (!fir::isa_complex(oldType)) {
242+
// Cast original operand to a tuple of the new arguments
243+
// via memory.
244+
llvm::SmallVector<mlir::Type> partTypes;
245+
for (auto argPart : splitArgs)
246+
partTypes.push_back(std::get<mlir::Type>(argPart));
247+
mlir::Type tupleType =
248+
mlir::TupleType::get(oldType.getContext(), partTypes);
249+
if (!savedStackPtr)
250+
savedStackPtr = genStackSave(loc);
251+
oper = convertValueInMemory(loc, oper, tupleType,
252+
/*inputMayBeBigger=*/false);
253+
}
254+
auto iTy = rewriter->getIntegerType(32);
255+
for (auto e : llvm::enumerate(splitArgs)) {
256+
auto &tup = e.value();
257+
auto ty = std::get<mlir::Type>(tup);
258+
auto index = e.index();
259+
auto idx = rewriter->getIntegerAttr(iTy, index);
260+
auto val = rewriter->create<fir::ExtractValueOp>(
261+
loc, ty, oper, rewriter->getArrayAttr(idx));
262+
newOpers.push_back(val);
263+
}
264+
}
265+
266+
void rewriteCallOperands(
267+
mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
268+
mlir::Type originalArgTy, mlir::Value oper,
269+
llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
270+
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
271+
if (passArgAs.size() == 1) {
272+
// COMPLEX or derived type is passed as a single argument.
273+
passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
274+
newOpers, savedStackPtr);
275+
} else {
276+
// COMPLEX or derived type is split into separate arguments
277+
passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
278+
savedStackPtr);
213279
}
280+
newInTyAndAttrs.insert(newInTyAndAttrs.end(), passArgAs.begin(),
281+
passArgAs.end());
214282
}
215283

216284
template <typename CPLX>
@@ -224,28 +292,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
224292
newOpers.push_back(oper);
225293
return;
226294
}
227-
228295
auto m = specifics->complexArgumentType(loc, ty.getElementType());
229-
if (m.size() == 1) {
230-
// COMPLEX is a single aggregate
231-
passArgumentOnStackOrWithNewType(loc, m[0], ty, oper, newOpers,
232-
savedStackPtr);
233-
newInTyAndAttrs.push_back(m[0]);
234-
} else {
235-
assert(m.size() == 2);
236-
// COMPLEX is split into 2 separate arguments
237-
auto iTy = rewriter->getIntegerType(32);
238-
for (auto e : llvm::enumerate(m)) {
239-
auto &tup = e.value();
240-
auto ty = std::get<mlir::Type>(tup);
241-
auto index = e.index();
242-
auto idx = rewriter->getIntegerAttr(iTy, index);
243-
auto val = rewriter->create<fir::ExtractValueOp>(
244-
loc, ty, oper, rewriter->getArrayAttr(idx));
245-
newInTyAndAttrs.push_back(tup);
246-
newOpers.push_back(val);
247-
}
248-
}
296+
rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
297+
newInTyAndAttrs);
249298
}
250299

251300
void rewriteCallStructInputType(
@@ -260,11 +309,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
260309
}
261310
auto structArgs =
262311
specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
263-
if (structArgs.size() != 1)
264-
TODO(loc, "splitting BIND(C), VALUE derived type into several arguments");
265-
passArgumentOnStackOrWithNewType(loc, structArgs[0], recTy, oper, newOpers,
266-
savedStackPtr);
267-
structArgs.push_back(structArgs[0]);
312+
rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
313+
newInTyAndAttrs);
268314
}
269315

270316
static bool hasByValOrSRetArgs(
@@ -849,20 +895,17 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
849895
case FixupTy::Codes::ArgumentType: {
850896
// Argument is pass-by-value, but its type has likely been modified to
851897
// suit the target ABI convention.
852-
auto oldArgTy =
853-
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
898+
auto oldArgTy = oldArgTys[fixup.index - offset];
854899
// If type did not change, keep the original argument.
855900
if (fixupType == oldArgTy)
856901
break;
857902

858903
auto newArg =
859904
func.front().insertArgument(fixup.index, fixupType, loc);
860905
rewriter->setInsertionPointToStart(&func.front());
861-
auto mem = rewriter->create<fir::AllocaOp>(loc, fixupType);
862-
rewriter->create<fir::StoreOp>(loc, newArg, mem);
863-
auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
864-
mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
865-
func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
906+
mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
907+
/*inputMayBeBigger=*/true);
908+
func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
866909
func.front().eraseArgument(fixup.index + 1);
867910
LLVM_DEBUG(llvm::dbgs()
868911
<< "old argument: " << oldArgTy.getEleTy()
@@ -907,34 +950,43 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
907950
func.walk([&](mlir::func::ReturnOp ret) {
908951
rewriter->setInsertionPoint(ret);
909952
auto oldOper = ret.getOperand(0);
910-
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
911-
auto mem =
912-
rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
913-
auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
914-
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
915-
mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
916-
rewriter->create<mlir::func::ReturnOp>(loc, load);
953+
mlir::Value bitcast =
954+
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
955+
/*inputMayBeBigger=*/false);
956+
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
917957
ret.erase();
918958
});
919959
} break;
920960
case FixupTy::Codes::Split: {
921961
// The FIR argument has been split into a pair of distinct arguments
922-
// that are in juxtaposition to each other. (For COMPLEX value.)
962+
// that are in juxtaposition to each other. (For COMPLEX value or
963+
// derived type passed with VALUE in BIND(C) context).
923964
auto newArg =
924965
func.front().insertArgument(fixup.index, fixupType, loc);
925966
if (fixup.second == 1) {
926967
rewriter->setInsertionPointToStart(&func.front());
927-
auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
928-
auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
968+
mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
969+
mlir::Type originalTy =
970+
oldArgTys[fixup.index - offset - fixup.second];
971+
mlir::Type pairTy = originalTy;
972+
if (!fir::isa_complex(originalTy)) {
973+
pairTy = mlir::TupleType::get(
974+
originalTy.getContext(),
975+
mlir::TypeRange{firstArg.getType(), newArg.getType()});
976+
}
977+
auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
929978
auto iTy = rewriter->getIntegerType(32);
930979
auto zero = rewriter->getIntegerAttr(iTy, 0);
931980
auto one = rewriter->getIntegerAttr(iTy, 1);
932-
auto cplx1 = rewriter->create<fir::InsertValueOp>(
933-
loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
934-
rewriter->getArrayAttr(zero));
935-
auto cplx = rewriter->create<fir::InsertValueOp>(
936-
loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
937-
func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
981+
mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
982+
loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
983+
mlir::Value pair = rewriter->create<fir::InsertValueOp>(
984+
loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
985+
// Cast local argument tuple to original type via memory if needed.
986+
if (pairTy != originalTy)
987+
pair = convertValueInMemory(loc, pair, originalTy,
988+
/*inputMayBeBigger=*/true);
989+
func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
938990
func.front().eraseArgument(fixup.index + 1);
939991
offset++;
940992
}

0 commit comments

Comments
 (0)