-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] AArch64 ABI for BIND(C) VALUE parameters #118305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> { | |
//===----------------------------------------------------------------------===// | ||
|
||
namespace { | ||
// AArch64 procedure call standard: | ||
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing | ||
struct TargetAArch64 : public GenericTarget<TargetAArch64> { | ||
using GenericTarget::GenericTarget; | ||
|
||
|
@@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> { | |
return marshal; | ||
} | ||
|
||
// Flatten a RecordType::TypeList containing more record types or array types | ||
// Flatten a RecordType::TypeList containing more record types or array type | ||
static std::optional<std::vector<mlir::Type>> | ||
flattenTypeList(const RecordType::TypeList &types) { | ||
std::vector<mlir::Type> flatTypes; | ||
|
@@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> { | |
|
||
// Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An | ||
// HFA is a record type with up to 4 floating-point members of the same type. | ||
static bool isHFA(fir::RecordType ty) { | ||
static std::optional<int> usedRegsForHFA(fir::RecordType ty) { | ||
RecordType::TypeList types = ty.getTypeList(); | ||
if (types.empty() || types.size() > 4) | ||
return false; | ||
return std::nullopt; | ||
|
||
std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types); | ||
if (!flatTypes || flatTypes->size() > 4) { | ||
return false; | ||
return std::nullopt; | ||
} | ||
|
||
if (!isa_real(flatTypes->front())) { | ||
return false; | ||
return std::nullopt; | ||
} | ||
|
||
return llvm::all_equal(*flatTypes); | ||
return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()} | ||
: std::nullopt; | ||
} | ||
|
||
// AArch64 procedure call ABI: | ||
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing | ||
CodeGenSpecifics::Marshalling | ||
structReturnType(mlir::Location loc, fir::RecordType ty) const override { | ||
CodeGenSpecifics::Marshalling marshal; | ||
// Register demand of a single type: `n` is the number of registers needed and
// `isSimd` selects the SIMD register budget instead of the integer one.
// A default-constructed value (n == 0) is what usedRegsForRecordType returns
// for a type that uses no registers.
struct NRegs { | ||
int n{0}; | ||
bool isSimd{false}; | ||
}; | ||
|
||
if (isHFA(ty)) { | ||
// Just return the existing record type | ||
marshal.emplace_back(ty, AT{}); | ||
return marshal; | ||
// Compute how many registers a record type occupies when passed by value.
// - An HFA uses one SIMD register per flattened member.
// - Any other record of at most 16 bytes uses one integer register per
//   8 bytes, rounded up.
// - Larger records yield a default NRegs (n == 0), i.e. they go on the stack.
NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const { | ||
// usedRegsForHFA yields the member count only when `type` is an HFA.
if (std::optional<int> size = usedRegsForHFA(type)) | ||
return {*size, true}; | ||
|
||
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( | ||
loc, type, getDataLayout(), kindMap); | ||
|
||
// (size + 7) / 8 rounds the byte size up to whole 8-byte units.
if (size <= 16) | ||
return {static_cast<int>((size + 7) / 8), false}; | ||
|
||
// Pass on the stack, i.e. no registers used | ||
return {}; | ||
} | ||
|
||
NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const { | ||
return llvm::TypeSwitch<mlir::Type, NRegs>(type) | ||
.Case<mlir::IntegerType>([&](auto intTy) { | ||
return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false}; | ||
}) | ||
.Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; }) | ||
.Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; }) | ||
.Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; }) | ||
.Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; }) | ||
.Case<fir::SequenceType>([&](auto ty) { | ||
assert(ty.getShape().size() == 1 && | ||
"invalid array dimensions in BIND(C)"); | ||
NRegs nregs = usedRegsForType(loc, ty.getEleTy()); | ||
nregs.n *= ty.getShape()[0]; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if this is an assumed-shape array? Then the shape in the type might have placeholder numbers here. On another note, I think getEleTy gets the scalar element type, no matter the rank. So wouldn't this need to work across all dimensions of the shape? I'm not sure what we do about assumed rank here. I guess that is not allowed for […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Arrays in general aren't actually allowed for BIND(C), nor as VALUE parameters even without BIND(C). I added the handling for scalar arrays here in case we use them elsewhere in lowering because it's easy to imagine what that would look like, and in case we use fir.array parameters/return values to represent other things, but maybe we should just throw an error here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Outside of the BIND(C) case, array dummy arguments can have the VALUE attribute, but are always lowered in memory (+copy if needed) already. Anyway, the extent should never be unknown in this context, but you can always add an assert. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the clarification. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
My mistake, I should have checked the standard rather than just checking gfortran :) |
||
return nregs; | ||
}) | ||
.Case<fir::RecordType>( | ||
[&](auto ty) { return usedRegsForRecordType(loc, ty); }) | ||
.Case<fir::VectorType>([&](auto) { | ||
TODO(loc, "passing vector argument to C by value is not supported"); | ||
return NRegs{}; | ||
}); | ||
} | ||
|
||
bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type, | ||
const Marshalling &previousArguments) const { | ||
int availIntRegisters = 8; | ||
int availSIMDRegisters = 8; | ||
Comment on lines
+938
to
+939
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be good to have a standard reference to make it clear where these numbers come from. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a comment linking to the spec lower down but I can move it up here, or maybe to the top of this class? |
||
|
||
// Check previous arguments to see how many registers are used already | ||
for (auto [type, attr] : previousArguments) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This feels like a lot of time complexity because presumably every argument will check all of its previous arguments. I guess in practice, this can't get too big because of the number of registers and you are limited by the API of struct argument type. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I took this method from the x86 handling; I'm not sure there's another way to do it with how this is structured really. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right, it is a bit dumb, but I do not expect the BIND(C) VALUE struct argument usage to be high enough to matter, so I did not modify the logic/interface too much when I added the X86-64 impl. The main "issue" is that we do not call the target lowering for all arguments, so the target lowering cannot maintain some register state properly. Adding callbacks for "normal" arguments could arguably increase the cost more in general because of the virtual aspects of the callback than doing the computation "again" for the few BIND(C) VALUE struct arguments. |
||
if (availIntRegisters <= 0 || availSIMDRegisters <= 0) | ||
break; | ||
|
||
if (attr.isByVal()) | ||
continue; // Previous argument passed on the stack | ||
|
||
NRegs nregs = usedRegsForType(loc, type); | ||
if (nregs.isSimd) | ||
availSIMDRegisters -= nregs.n; | ||
else | ||
availIntRegisters -= nregs.n; | ||
} | ||
|
||
auto [size, align] = | ||
NRegs nregs = usedRegsForRecordType(loc, type); | ||
|
||
if (nregs.isSimd) | ||
return nregs.n <= availSIMDRegisters; | ||
|
||
return nregs.n <= availIntRegisters; | ||
} | ||
|
||
// Marshal `ty` as passed in memory: the result holds a single
// fir::ReferenceType to `ty`, flagged byval when it is an argument
// (isResult == false) and sret when it is a result (isResult == true).
CodeGenSpecifics::Marshalling | ||
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const { | ||
CodeGenSpecifics::Marshalling marshal; | ||
auto sizeAndAlign = | ||
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap); | ||
// The stack is always 8 byte aligned | ||
// so use 8 as the minimum and keep the type's own alignment when larger.
unsigned short align = | ||
std::max(sizeAndAlign.second, static_cast<unsigned short>(8)); | ||
marshal.emplace_back(fir::ReferenceType::get(ty), | ||
AT{align, /*byval=*/!isResult, /*sret=*/isResult}); | ||
return marshal; | ||
} | ||
|
||
// return in registers if size <= 16 bytes | ||
if (size <= 16) { | ||
std::size_t dwordSize = (size + 7) / 8; | ||
auto newTy = fir::SequenceType::get( | ||
dwordSize, mlir::IntegerType::get(ty.getContext(), 64)); | ||
marshal.emplace_back(newTy, AT{}); | ||
return marshal; | ||
CodeGenSpecifics::Marshalling | ||
structType(mlir::Location loc, fir::RecordType type, bool isResult) const { | ||
NRegs nregs = usedRegsForRecordType(loc, type); | ||
|
||
// If the type needs no registers it must need to be passed on the stack | ||
if (nregs.n == 0) | ||
return passOnTheStack(loc, type, isResult); | ||
|
||
CodeGenSpecifics::Marshalling marshal; | ||
|
||
mlir::Type pcsType; | ||
if (nregs.isSimd) { | ||
pcsType = type; | ||
} else { | ||
pcsType = fir::SequenceType::get( | ||
nregs.n, mlir::IntegerType::get(type.getContext(), 64)); | ||
} | ||
|
||
unsigned short stackAlign = std::max<unsigned short>(align, 8u); | ||
marshal.emplace_back(fir::ReferenceType::get(ty), | ||
AT{stackAlign, false, true}); | ||
marshal.emplace_back(pcsType, AT{}); | ||
return marshal; | ||
} | ||
|
||
// Marshal a by-value struct argument. When hasEnoughRegisters reports that
// the registers left over after `previousArguments` cannot hold `ty`, the
// argument is passed on the stack; otherwise the choice is left to structType.
CodeGenSpecifics::Marshalling | ||
structArgumentType(mlir::Location loc, fir::RecordType ty, | ||
const Marshalling &previousArguments) const override { | ||
if (!hasEnoughRegisters(loc, ty, previousArguments)) { | ||
return passOnTheStack(loc, ty, /*isResult=*/false); | ||
} | ||
|
||
return structType(loc, ty, /*isResult=*/false); | ||
} | ||
|
||
// Marshal a struct result by delegating to structType with isResult set.
// Unlike structArgumentType, no register-availability check is made here.
CodeGenSpecifics::Marshalling | ||
structReturnType(mlir::Location loc, fir::RecordType ty) const override { | ||
return structType(loc, ty, /*isResult=*/true); | ||
} | ||
}; | ||
} // namespace | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types). | ||
// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>) | ||
func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>) | ||
// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>) | ||
func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>) | ||
// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>) | ||
func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>) | ||
// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>) | ||
func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>) | ||
|
||
// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>) | ||
func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>) | ||
// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>) | ||
func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>) | ||
// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>) | ||
func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>) | ||
// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>) | ||
func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>) | ||
// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>) | ||
func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>) | ||
|
||
// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>) | ||
func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>) | ||
// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>) | ||
func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>) | ||
// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>) | ||
func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>) | ||
|
||
// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>) | ||
func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>) | ||
// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>) | ||
func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>) | ||
// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>) | ||
func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>) | ||
|
||
|
||
// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.array<2xi64>, | ||
// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>}) | ||
func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>, | ||
!fir.type<int_max{i:i64,j:i64}>) | ||
// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>}) | ||
func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, | ||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>) | ||
|
||
// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>}) | ||
func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>) |
Uh oh!
There was an error while loading. Please reload this page.