[flang] AArch64 ABI for BIND(C) VALUE parameters #118305

Merged
merged 2 commits into from
Dec 18, 2024

Conversation

DavidTruby
Member

This patch adds handling for derived type VALUE parameters in BIND(C)
functions for AArch64.

DavidTruby requested a review from tblah on December 2, 2024 14:52
llvmbot added the flang, flang:fir-hlfir, and flang:codegen labels on Dec 2, 2024

llvmbot commented Dec 2, 2024

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Truby (DavidTruby)


Full diff: https://github.com/llvm/llvm-project/pull/118305.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+116-24)
  • (added) flang/test/Fir/struct-passing-aarch64-byval.fir (+73)
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index f7bffbf53c190e..0d865ee09535a3 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -826,7 +826,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     return marshal;
   }
 
-  // Flatten a RecordType::TypeList containing more record types or array types
+  // Flatten a RecordType::TypeList containing more record types or array type
   static std::optional<std::vector<mlir::Type>>
   flattenTypeList(const RecordType::TypeList &types) {
     std::vector<mlir::Type> flatTypes;
@@ -870,51 +870,143 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
 
   // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
   // HFA is a record type with up to 4 floating-point members of the same type.
-  static bool isHFA(fir::RecordType ty) {
+  static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
     RecordType::TypeList types = ty.getTypeList();
     if (types.empty() || types.size() > 4)
-      return false;
+      return std::nullopt;
 
     std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
     if (!flatTypes || flatTypes->size() > 4) {
-      return false;
+      return std::nullopt;
     }
 
     if (!isa_real(flatTypes->front())) {
-      return false;
+      return std::nullopt;
+    }
+
+    return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
+                                       : std::nullopt;
+  }
+
+  struct NRegs {
+    int n{0};
+    bool isSimd{false};
+  };
+
+  NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
+    if (std::optional<int> size = usedRegsForHFA(type))
+      return {*size, true};
+
+    auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
+        loc, type, getDataLayout(), kindMap);
+
+    if (size <= 16)
+      return {static_cast<int>((size + 7) / 8), false};
+
+    // Pass on the stack, i.e. no registers used
+    return {};
+  }
+
+  NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
+    return llvm::TypeSwitch<mlir::Type, NRegs>(type)
+        .Case<mlir::IntegerType>([&](auto intTy) {
+          return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
+        })
+        .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
+        .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
+        .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::SequenceType>([&](auto ty) {
+          NRegs nregs = usedRegsForType(loc, ty.getEleTy());
+          nregs.n *= ty.getShape()[0];
+          return nregs;
+        })
+        .Case<fir::RecordType>(
+            [&](auto ty) { return usedRegsForRecordType(loc, ty); })
+        .Case<fir::VectorType>([&](auto) {
+          TODO(loc, "passing vector argument to C by value is not supported");
+          return NRegs{};
+        });
+  }
+
+  bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 8;
+    int availSIMDRegisters = 8;
+
+    // Check previous arguments to see how many registers are used already
+    for (auto [type, attr] : previousArguments) {
+      if (availIntRegisters <= 0 || availSIMDRegisters <= 0)
+        break;
+
+      if (attr.isByVal())
+        continue; // Previous argument passed on the stack
+
+      NRegs nregs = usedRegsForType(loc, type);
+      if (nregs.isSimd)
+        availSIMDRegisters -= nregs.n;
+      else
+        availIntRegisters -= nregs.n;
     }
 
-    return llvm::all_equal(*flatTypes);
+    NRegs nregs = usedRegsForRecordType(loc, type);
+
+    if (nregs.isSimd)
+      return nregs.n <= availSIMDRegisters;
+
+    return nregs.n <= availIntRegisters;
+  }
+
+  CodeGenSpecifics::Marshalling
+  passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
+    CodeGenSpecifics::Marshalling marshal;
+    auto sizeAndAlign =
+        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+    // The stack is always 8 byte aligned
+    unsigned short align =
+        std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{align, /*byval=*/!isResult, /*sret=*/isResult});
+    return marshal;
   }
 
   // AArch64 procedure call ABI:
   // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
   CodeGenSpecifics::Marshalling
-  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+  structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
+    NRegs nregs = usedRegsForRecordType(loc, type);
+
+    // If the type needs no registers it must need to be passed on the stack
+    if (nregs.n == 0)
+      return passOnTheStack(loc, type, isResult);
+
     CodeGenSpecifics::Marshalling marshal;
 
-    if (isHFA(ty)) {
-      // Just return the existing record type
-      marshal.emplace_back(ty, AT{});
-      return marshal;
+    mlir::Type pcsType;
+    if (nregs.isSimd) {
+      pcsType = type;
+    } else {
+      pcsType = fir::SequenceType::get(
+          nregs.n, mlir::IntegerType::get(type.getContext(), 64));
     }
 
-    auto [size, align] =
-        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+    marshal.emplace_back(pcsType, AT{});
+    return marshal;
+  }
 
-    // return in registers if size <= 16 bytes
-    if (size <= 16) {
-      std::size_t dwordSize = (size + 7) / 8;
-      auto newTy = fir::SequenceType::get(
-          dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
-      marshal.emplace_back(newTy, AT{});
-      return marshal;
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType ty,
+                     const Marshalling &previousArguments) const override {
+    if (!hasEnoughRegisters(loc, ty, previousArguments)) {
+      return passOnTheStack(loc, ty, /*isResult=*/false);
     }
 
-    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
-    marshal.emplace_back(fir::ReferenceType::get(ty),
-                         AT{stackAlign, false, true});
-    return marshal;
+    return structType(loc, ty, /*isResult=*/false);
+  }
+
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+    return structType(loc, ty, /*isResult=*/true);
   }
 };
 } // namespace
diff --git a/flang/test/Fir/struct-passing-aarch64-byval.fir b/flang/test/Fir/struct-passing-aarch64-byval.fir
new file mode 100644
index 00000000000000..27143459dde2f2
--- /dev/null
+++ b/flang/test/Fir/struct-passing-aarch64-byval.fir
@@ -0,0 +1,73 @@
+// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
+func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
+// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
+func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
+func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
+// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
+func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)
+
+// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+
+// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
+func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
+func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)
+
+// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                       !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+
+
+// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
+func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
+func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                           !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                           !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+
+// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
+func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)
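For readers skimming the diff, the per-record rule in `usedRegsForRecordType` can be sketched as a standalone function. This is only an illustration under stated assumptions: `classifyRecord` and its two parameters are hypothetical stand-ins for the HFA check and the data-layout size query in the patch, and `NRegs` here merely mirrors the struct added by the patch rather than being the flang type.

```cpp
#include <cstdint>
#include <optional>

// Mirrors the shape of the patch's NRegs; not the flang type itself.
struct NRegs {
  int n{0};
  bool isSimd{false};
};

// hfaMembers: number of same-typed floating-point members (1..4) when the
// record is an HFA, std::nullopt otherwise (hypothetical parameter).
// sizeInBytes: total size of the record (hypothetical parameter).
NRegs classifyRecord(std::optional<int> hfaMembers,
                     std::uint64_t sizeInBytes) {
  if (hfaMembers)
    return {*hfaMembers, true}; // one SIMD register per HFA member
  if (sizeInBytes <= 16)
    // Round the size up to whole 8-byte words, one integer register each.
    return {static_cast<int>((sizeInBytes + 7) / 8), false};
  return {}; // n == 0 means "no registers": pass in memory
}
```

With these assumptions, a 12-byte record such as `small_i32` maps to two integer registers (hence `!fir.array<2xi64>` in the test), while the 20-byte `too_big` record maps to zero registers and is passed by reference with `llvm.byval`.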

        .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
        .Case<fir::SequenceType>([&](auto ty) {
          NRegs nregs = usedRegsForType(loc, ty.getEleTy());
          nregs.n *= ty.getShape()[0];
Contributor

What if this is an assumed shape array? Then the shape in the type might have placeholder numbers here.

On another note, I think getEleTy returns the scalar element type, no matter the rank. So wouldn't this need to multiply across all dimensions of the shape?

I'm not sure what we do about assumed rank here. I guess that is not allowed for bind(c)?

Member Author

Arrays in general aren't actually allowed for BIND(C), nor as VALUE parameters even without BIND(C). I added the handling for scalar arrays here in case we use them elsewhere in lowering, because it's easy to imagine what that would look like, and in case we use fir.array parameters/return values to represent other things. But maybe we should just throw an error here?

Contributor

@jeanPerier jeanPerier Dec 5, 2024

> Arrays in general aren't actually allowed for BIND(C), nor as VALUE parameters even without BIND(C).

Outside of the BIND(C) case array dummy arguments can have the VALUE attribute, but are always lowered in memory (+copy if needed) already.

Anyway, the extent should never be unknown in this context, but you can always add an assert.
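A minimal sketch of what the two suggestions in this thread could look like together: multiplying over every dimension of the shape and asserting that each extent is known. The names `constantElementCount` and `kUnknownExtent` are hypothetical and chosen for illustration; the sentinel value is an assumption, not the value FIR uses.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sentinel for an extent that is not known at compile time.
constexpr std::int64_t kUnknownExtent = -1;

// Number of scalar elements in an array with the given shape, taking all
// dimensions into account rather than only the first one.
std::int64_t constantElementCount(const std::vector<std::int64_t> &shape) {
  std::int64_t count = 1;
  for (std::int64_t extent : shape) {
    // By-value passing only makes sense when every extent is a constant.
    assert(extent != kUnknownExtent && "array extent must be known here");
    count *= extent;
  }
  return count;
}
```

The register count for the array would then be the element's register count multiplied by this total, instead of by `getShape()[0]` alone.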

Contributor

Thanks for the clarification

Member Author

@DavidTruby DavidTruby Dec 5, 2024

> Outside of the BIND(C) case array dummy arguments can have the VALUE attribute

My mistake, I should have checked the standard rather than just checking gfortran :)

Comment on lines +934 to +939
int availIntRegisters = 8;
int availSIMDRegisters = 8;
Contributor

It would be good to have a standard reference to make it clear where these numbers come from.

Member Author

There's a comment linking to the spec lower down but I can move it up here, or maybe to the top of this class?

int availSIMDRegisters = 8;

// Check previous arguments to see how many registers are used already
for (auto [type, attr] : previousArguments) {
Contributor

This feels like a lot of time complexity because presumably every argument will check all of its previous arguments.

I guess in practice this can't get too big because of the number of registers, and you are limited by the API of structArgumentType.

Member Author

I took this method from the x86 handling; I'm not sure there's another way to do it with how this is structured really.

Contributor

@jeanPerier jeanPerier Dec 5, 2024

Right, it is a bit dumb, but I do not expect BIND(C) VALUE struct arguments to be used often enough for this to matter, so I did not modify the logic/interface too much when I added the X86-64 impl.

The main "issue" is that we do not call the target lowering for all arguments, so the target lowering cannot maintain some register state properly.

Adding callbacks for "normal" arguments could arguably cost more in general, because of the virtual call made for every argument, than doing the computation "again" for the few BIND(C) VALUE struct arguments.
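The re-scan being discussed can be sketched in isolation as follows. This is a simplified illustration, not the code in the patch: `PrevArg` is a hypothetical summary of an already-lowered argument, `kArgRegsPerClass` is a hypothetical name for the constant 8 (AAPCS64 passes arguments in x0-x7 and v0-v7), and the early exit present in the patch's loop is omitted.

```cpp
#include <vector>

// Mirrors the shape of the patch's NRegs; not the flang type itself.
struct NRegs {
  int n{0};
  bool isSimd{false};
};

// Hypothetical summary of one previously lowered argument.
struct PrevArg {
  NRegs regs;
  bool byVal{false}; // true if the argument was already passed in memory
};

// Eight general-purpose and eight SIMD/FP argument registers.
constexpr int kArgRegsPerClass = 8;

// Recompute register usage from scratch for each candidate struct argument.
bool hasEnoughRegisters(NRegs candidate,
                        const std::vector<PrevArg> &previous) {
  int availInt = kArgRegsPerClass;
  int availSimd = kArgRegsPerClass;
  for (const PrevArg &arg : previous) {
    if (arg.byVal)
      continue; // in memory, consumes no argument registers
    if (arg.regs.isSimd)
      availSimd -= arg.regs.n;
    else
      availInt -= arg.regs.n;
  }
  return candidate.n <= (candidate.isSimd ? availSimd : availInt);
}
```

Each call walks all earlier arguments, so lowering k struct arguments costs on the order of k times the argument count. As noted in the thread, this is expected to stay small because only BIND(C) VALUE struct arguments trigger the scan.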

@DavidTruby DavidTruby merged commit 44aa476 into llvm:main Dec 18, 2024
8 checks passed
@DavidTruby DavidTruby deleted the bindc-value branch December 18, 2024 07:44
mmuetzel referenced this pull request in msys2/MINGW-packages Mar 4, 2025