Skip to content

[flang] AArch64 ABI for BIND(C) VALUE parameters #118305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 120 additions & 26 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> {
//===----------------------------------------------------------------------===//

namespace {
// AArch64 procedure call standard:
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
struct TargetAArch64 : public GenericTarget<TargetAArch64> {
using GenericTarget::GenericTarget;

Expand Down Expand Up @@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
return marshal;
}

// Flatten a RecordType::TypeList containing more record types or array types
// Flatten a RecordType::TypeList containing more record types or array type
static std::optional<std::vector<mlir::Type>>
flattenTypeList(const RecordType::TypeList &types) {
std::vector<mlir::Type> flatTypes;
Expand Down Expand Up @@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {

// Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
// HFA is a record type with up to 4 floating-point members of the same type.
static bool isHFA(fir::RecordType ty) {
static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
RecordType::TypeList types = ty.getTypeList();
if (types.empty() || types.size() > 4)
return false;
return std::nullopt;

std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
if (!flatTypes || flatTypes->size() > 4) {
return false;
return std::nullopt;
}

if (!isa_real(flatTypes->front())) {
return false;
return std::nullopt;
}

return llvm::all_equal(*flatTypes);
return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
: std::nullopt;
}

// AArch64 procedure call ABI:
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
CodeGenSpecifics::Marshalling marshal;
struct NRegs {
int n{0};
bool isSimd{false};
};

if (isHFA(ty)) {
// Just return the existing record type
marshal.emplace_back(ty, AT{});
return marshal;
NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
if (std::optional<int> size = usedRegsForHFA(type))
return {*size, true};

auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
loc, type, getDataLayout(), kindMap);

if (size <= 16)
return {static_cast<int>((size + 7) / 8), false};

// Pass on the stack, i.e. no registers used
return {};
}

NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
return llvm::TypeSwitch<mlir::Type, NRegs>(type)
.Case<mlir::IntegerType>([&](auto intTy) {
return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
})
.Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
.Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
.Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
.Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
.Case<fir::SequenceType>([&](auto ty) {
assert(ty.getShape().size() == 1 &&
"invalid array dimensions in BIND(C)");
NRegs nregs = usedRegsForType(loc, ty.getEleTy());
nregs.n *= ty.getShape()[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if this is an assumed shape array? Then the shape in the type might have placeholder numbers here.

On another note, I think get eleTy gets the scalar element type, no matter the rank. So wouldn't this need to work across all dimensions of the shape?

I'm not sure what we do about assumed rank here. I guess that is not allowed for bind(c)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arrays in general aren't actually allowed for BIND(C), nor as VALUE parameters even without BIND(C). I added the handling for scalar arrays here in case we use them elsewhere in lowering because it's easy to imagine what that would look like, and in case we use fir.array parameters/return values to represent other things but maybe we should just throw an error here?

Copy link
Contributor

@jeanPerier jeanPerier Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arrays in general aren't actually allowed for BIND(C), nor as VALUE parameters even without BIND(C).

Outside of the BIND(C) case array dummy arguments can have the VALUE attribute, but are always lowered in memory (+copy if needed) already.

Anyway, the extent it should never be unknown in this context, but you can always add an assert.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification

Copy link
Member Author

@DavidTruby DavidTruby Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of the BIND(C) case array dummy arguments can have the VALUE attribute

My mistake, I should have checked the standard rather than just checking gfortran :)

return nregs;
})
.Case<fir::RecordType>(
[&](auto ty) { return usedRegsForRecordType(loc, ty); })
.Case<fir::VectorType>([&](auto) {
TODO(loc, "passing vector argument to C by value is not supported");
return NRegs{};
});
}

bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
const Marshalling &previousArguments) const {
int availIntRegisters = 8;
int availSIMDRegisters = 8;
Comment on lines +938 to +939
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to have a standard reference to make it clear where these numbers come from.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a comment linking to the spec lower down but I can move it up here, or maybe to the top of this class?


// Check previous arguments to see how many registers are used already
for (auto [type, attr] : previousArguments) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like a lot of time complexity because presumably every argument will check all of its previous arguments.

I guess in practice, this can't get too big because of the number of registers and you are limited by the API of struct argument type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took this method from the x86 handling; I'm not sure there's another way to do it with how this is structured really.

Copy link
Contributor

@jeanPerier jeanPerier Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, it is a bit dumb, but I do not expect the BIND(C) VALUE struct argument usage to be high enough so I did not modify the logic/interface too much when I added the X86-64 impl.

The main "issue" is that we do not call the target lowering for all arguments, so the target lowering cannot maintain some register state properly.

Adding callbacks for "normal" arguments could arguably increase the cost more in general because of the virtual aspects of the callback than doing the computation "again" for the few BIND(C) VALUE struct arguments.

if (availIntRegisters <= 0 || availSIMDRegisters <= 0)
break;

if (attr.isByVal())
continue; // Previous argument passed on the stack

NRegs nregs = usedRegsForType(loc, type);
if (nregs.isSimd)
availSIMDRegisters -= nregs.n;
else
availIntRegisters -= nregs.n;
}

auto [size, align] =
NRegs nregs = usedRegsForRecordType(loc, type);

if (nregs.isSimd)
return nregs.n <= availSIMDRegisters;

return nregs.n <= availIntRegisters;
}

CodeGenSpecifics::Marshalling
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
CodeGenSpecifics::Marshalling marshal;
auto sizeAndAlign =
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
// The stack is always 8 byte aligned
unsigned short align =
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
marshal.emplace_back(fir::ReferenceType::get(ty),
AT{align, /*byval=*/!isResult, /*sret=*/isResult});
return marshal;
}

// return in registers if size <= 16 bytes
if (size <= 16) {
std::size_t dwordSize = (size + 7) / 8;
auto newTy = fir::SequenceType::get(
dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
marshal.emplace_back(newTy, AT{});
return marshal;
CodeGenSpecifics::Marshalling
structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
NRegs nregs = usedRegsForRecordType(loc, type);

// If the type needs no registers it must need to be passed on the stack
if (nregs.n == 0)
return passOnTheStack(loc, type, isResult);

CodeGenSpecifics::Marshalling marshal;

mlir::Type pcsType;
if (nregs.isSimd) {
pcsType = type;
} else {
pcsType = fir::SequenceType::get(
nregs.n, mlir::IntegerType::get(type.getContext(), 64));
}

unsigned short stackAlign = std::max<unsigned short>(align, 8u);
marshal.emplace_back(fir::ReferenceType::get(ty),
AT{stackAlign, false, true});
marshal.emplace_back(pcsType, AT{});
return marshal;
}

CodeGenSpecifics::Marshalling
structArgumentType(mlir::Location loc, fir::RecordType ty,
const Marshalling &previousArguments) const override {
if (!hasEnoughRegisters(loc, ty, previousArguments)) {
return passOnTheStack(loc, ty, /*isResult=*/false);
}

return structType(loc, ty, /*isResult=*/false);
}

CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
return structType(loc, ty, /*isResult=*/true);
}
};
} // namespace

Expand Down
73 changes: 73 additions & 0 deletions flang/test/Fir/struct-passing-aarch64-byval.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s

// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)

// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)

// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)

// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>)
func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>)
func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)


// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)

// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)
Loading