[flang] Add struct passing target rewrite hooks and partial X86-64 impl #74829

Merged on Dec 12, 2023 (7 commits)

jeanPerier
Contributor

In the context of C/Fortran interoperability (BIND(C)), it is possible to give the VALUE attribute to a BIND(C) derived type dummy argument, which according to Fortran 2018 18.3.6 - 2. (4) implies that it must be passed like the equivalent C structure value. The way C structure values are passed is ABI dependent.

LLVM does not implement the C struct passing ABI for LLVM aggregate type arguments. It is up to the front end, as clang does, to split the struct into registers or pass the struct on the stack (llvm "byval") as required by the target ABI.
So the logic for C struct passing sits in clang. Using it from flang would require setting up a lot of clang context and bridging the FIR/MLIR representation to the clang AST representation for function signatures (in both directions). It is a non-trivial task.
See https://stackoverflow.com/questions/39438033/passing-structs-by-value-in-llvm-ir/75002581#75002581.

Since BIND(C) structs are rather limited compared to generic C structs (e.g. no bit fields), it is easier to provide a limited implementation for the cases that matter to Fortran.

This patch:

  • Updates the generic target rewrite pass to keep track of both the new argument type and its attributes. The motivation for this is to be able to tell whether a previously marshalled argument is passed in memory (it is a C pointer) or on the stack (it has the LLVM byval attribute).
  • Adds an entry point in the target-specific codegen to marshal struct arguments, and uses it in the generic target rewrite pass.
  • Implements limited support for the X86-64 case. So far, the support allows telling whether a struct must be passed in registers or on the stack, and handles the stack case. The register case is left as a TODO in this patch.
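The first bullet can be pictured with a small standalone model. This is only a sketch, not flang's code: `ArgAttrs`, `MarshalledArg`, `PassKind`, `passKind` and `countRegisterArgs` are hypothetical names, and types are plain strings rather than `mlir::Type`. It shows why keeping the attributes next to each rewritten type lets a later step tell a plain C pointer argument apart from a struct copied on the stack with `byval`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the attributes attached to a rewritten argument.
struct ArgAttrs {
  bool byVal;         // models the LLVM "byval" attribute
  unsigned alignment; // 0 means no explicit alignment
};

// Hypothetical stand-in for one (type, attributes) marshalling entry.
struct MarshalledArg {
  std::string type; // textual placeholder for the rewritten type
  ArgAttrs attrs;
};

enum class PassKind { Value, Pointer, StackCopy };

// Decide how an already rewritten argument is passed, using the type and
// the attributes together. The type alone cannot separate the last two cases.
PassKind passKind(const MarshalledArg &arg) {
  // In this model, byval only makes sense on a pointer-typed argument.
  assert(!arg.attrs.byVal || arg.type == "ptr");
  if (arg.type != "ptr")
    return PassKind::Value;
  return arg.attrs.byVal ? PassKind::StackCopy : PassKind::Pointer;
}

// Count the arguments that occupy a register in this model: everything
// except byval stack copies.
int countRegisterArgs(const std::vector<MarshalledArg> &args) {
  int count = 0;
  for (const MarshalledArg &arg : args)
    if (passKind(arg) != PassKind::StackCopy)
      ++count;
  return count;
}
```

With only the type recorded, the second and third entries of a list such as `i32, ptr, ptr byval` would be indistinguishable, which is the situation the pass change avoids.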

The X86-64 ABI implemented is the System V ABI for AMD64, version 1.0.
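To give an idea of the register-versus-stack decision, here is a much reduced sketch of the eightbyte classification idea from the System V ABI. It is not the code from this patch: `Cls`, `Field`, `Classes`, `mergeCls` and `classifyFlatStruct` are hypothetical names, and the model assumes a flat struct of naturally aligned scalars of 1, 2, 4 or 8 bytes (no arrays, nested types, x87 or vector types). Under those assumptions, anything that does not fit in two eightbytes is treated as MEMORY.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class Cls { NoClass, Integer, SSE, Memory };

// Hypothetical scalar field: size in bytes (1, 2, 4 or 8), naturally aligned.
struct Field {
  std::uint64_t size;
  bool isFloat;
};

struct Classes {
  Cls lo; // class of bytes 0..7
  Cls hi; // class of bytes 8..15
};

// Merge two classes that land in the same eightbyte: MEMORY dominates,
// then INTEGER wins over SSE.
Cls mergeCls(Cls a, Cls b) {
  if (a == b) return a;
  if (a == Cls::NoClass) return b;
  if (b == Cls::NoClass) return a;
  if (a == Cls::Memory || b == Cls::Memory) return Cls::Memory;
  if (a == Cls::Integer || b == Cls::Integer) return Cls::Integer;
  return Cls::SSE;
}

Classes classifyFlatStruct(const std::vector<Field> &fields) {
  Classes result{Cls::NoClass, Cls::NoClass};
  std::uint64_t offset = 0;
  for (const Field &f : fields) {
    assert(f.size == 1 || f.size == 2 || f.size == 4 || f.size == 8);
    // Natural alignment: round the offset up to a multiple of the size.
    offset = (offset + f.size - 1) / f.size * f.size;
    if (offset + f.size > 16)
      return {Cls::Memory, Cls::Memory}; // does not fit in two eightbytes
    Cls &slot = offset < 8 ? result.lo : result.hi;
    slot = mergeCls(slot, f.isFloat ? Cls::SSE : Cls::Integer);
    offset += f.size;
  }
  return result;
}
```

For instance, a struct holding an `int` followed by a `float` occupies a single INTEGER eightbyte in this model, while three 8-byte fields end up in memory.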

@llvmbot added the flang, flang:fir-hlfir, and flang:codegen labels Dec 8, 2023
@llvmbot
Member

llvmbot commented Dec 8, 2023

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Patch is 77.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74829.diff

13 Files Affected:

  • (modified) flang/include/flang/Optimizer/CodeGen/CGPasses.td (+3-2)
  • (modified) flang/include/flang/Optimizer/CodeGen/Target.h (+29-5)
  • (modified) flang/include/flang/Optimizer/CodeGen/TypeConverter.h (+5-1)
  • (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+10)
  • (modified) flang/include/flang/Optimizer/Support/DataLayout.h (+12)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+13-1)
  • (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+322-14)
  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+296-163)
  • (modified) flang/lib/Optimizer/CodeGen/TypeConverter.cpp (+3-2)
  • (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+6)
  • (modified) flang/lib/Optimizer/Support/DataLayout.cpp (+13)
  • (modified) flang/test/Fir/recursive-type-tco.fir (+2-2)
  • (added) flang/test/Fir/struct-passing-x86-64-byval.fir (+103)
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index 0014298a27a22..5e47119582776 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -23,7 +23,7 @@ def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
     will also convert ops in the standard and FIRCG dialects.
   }];
   let constructor = "::fir::createFIRToLLVMPass()";
-  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+  let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::DLTIDialect"];
   let options = [
     Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
            "Override module's target triple.">,
@@ -53,7 +53,8 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
       representations that may differ based on the target machine.
   }];
   let constructor = "::fir::createFirTargetRewritePass()";
-  let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect" ];
+  let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
+                            "mlir::DLTIDialect" ];
   let options = [
     Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
            "Override module's target triple.">,
diff --git a/flang/include/flang/Optimizer/CodeGen/Target.h b/flang/include/flang/Optimizer/CodeGen/Target.h
index acffe6c1cec9c..c3ef521ced120 100644
--- a/flang/include/flang/Optimizer/CodeGen/Target.h
+++ b/flang/include/flang/Optimizer/CodeGen/Target.h
@@ -13,6 +13,7 @@
 #ifndef FORTRAN_OPTMIZER_CODEGEN_TARGET_H
 #define FORTRAN_OPTMIZER_CODEGEN_TARGET_H
 
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/TargetParser/Triple.h"
@@ -20,6 +21,10 @@
 #include <tuple>
 #include <vector>
 
+namespace mlir {
+class DataLayout;
+}
+
 namespace fir {
 
 namespace details {
@@ -62,14 +67,20 @@ class Attributes {
 class CodeGenSpecifics {
 public:
   using Attributes = details::Attributes;
-  using Marshalling = std::vector<std::tuple<mlir::Type, Attributes>>;
+  using TypeAndAttr = std::tuple<mlir::Type, Attributes>;
+  using Marshalling = std::vector<TypeAndAttr>;
+
+  static std::unique_ptr<CodeGenSpecifics> get(mlir::MLIRContext *ctx,
+                                               llvm::Triple &&trp,
+                                               KindMapping &&kindMap,
+                                               const mlir::DataLayout &dl);
 
-  static std::unique_ptr<CodeGenSpecifics>
-  get(mlir::MLIRContext *ctx, llvm::Triple &&trp, KindMapping &&kindMap);
+  static TypeAndAttr getTypeAndAttr(mlir::Type t) { return TypeAndAttr{t, {}}; }
 
   CodeGenSpecifics(mlir::MLIRContext *ctx, llvm::Triple &&trp,
-                   KindMapping &&kindMap)
-      : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)} {}
+                   KindMapping &&kindMap, const mlir::DataLayout &dl)
+      : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)},
+        dataLayout{&dl} {}
   CodeGenSpecifics() = delete;
   virtual ~CodeGenSpecifics() {}
 
@@ -90,6 +101,13 @@ class CodeGenSpecifics {
   /// Type presentation of a `boxchar<n>` type value in memory.
   virtual mlir::Type boxcharMemoryType(mlir::Type eleTy) const = 0;
 
+  /// Type representation of a `fir.type<T>` type argument when passed by
+  /// value. It may have to be split into several arguments, or be passed
+  /// as a byval reference argument (on the stack).
+  virtual Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType recTy,
+                     const Marshalling &previousArguments) const = 0;
+
   /// Type representation of a `boxchar<n>` type argument when passed by value.
   /// An argument value may need to be passed as a (safe) reference argument.
   ///
@@ -143,10 +161,16 @@ class CodeGenSpecifics {
   // Returns width in bits of C/C++ 'int' type size.
   virtual unsigned char getCIntTypeWidth() const = 0;
 
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout && "dataLayout must be set");
+    return *dataLayout;
+  }
+
 protected:
   mlir::MLIRContext &context;
   llvm::Triple triple;
   KindMapping kindMap;
+  const mlir::DataLayout *dataLayout = nullptr;
 };
 
 } // namespace fir
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 9ce756bdfd966..396c136392555 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -39,6 +39,10 @@ static constexpr unsigned kDimLowerBoundPos = 0;
 static constexpr unsigned kDimExtentPos = 1;
 static constexpr unsigned kDimStridePos = 2;
 
+namespace mlir {
+class DataLayout;
+}
+
 namespace fir {
 
 /// FIR type converter
@@ -46,7 +50,7 @@ namespace fir {
 class LLVMTypeConverter : public mlir::LLVMTypeConverter {
 public:
   LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
-                    bool forceUnifiedTBAATree);
+                    bool forceUnifiedTBAATree, const mlir::DataLayout &);
 
   // i32 is used here because LLVM wants i32 constants when indexing into struct
   // types. Indexing into other aggregate types is more flexible.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 51608e3c1d63e..2a2f50720859e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -326,6 +326,8 @@ def fir_RealType : FIR_Type<"Real", "real"> {
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
+    // Get MLIR float type with same semantics.
+    mlir::Type getFloatType(const fir::KindMapping &kindMap) const;
   }];
 
   let genVerifyDecl = 1;
@@ -495,6 +497,14 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
     static constexpr Extent getUnknownExtent() {
       return mlir::ShapedType::kDynamic;
     }
+
+    std::uint64_t getConstantArraySize() {
+      assert(!hasDynamicExtents() && "array type must have constant shape");
+      std::uint64_t size = 1;
+      for (Extent extent : getShape())
+        size = size * static_cast<std::uint64_t>(extent);
+      return size;
+    }
   }];
 }
 
diff --git a/flang/include/flang/Optimizer/Support/DataLayout.h b/flang/include/flang/Optimizer/Support/DataLayout.h
index 88ff575a8ff08..d21576bb95f79 100644
--- a/flang/include/flang/Optimizer/Support/DataLayout.h
+++ b/flang/include/flang/Optimizer/Support/DataLayout.h
@@ -13,6 +13,9 @@
 #ifndef FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
 #define FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
 
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include <optional>
+
 namespace mlir {
 class ModuleOp;
 }
@@ -34,6 +37,15 @@ void setMLIRDataLayout(mlir::ModuleOp mlirModule, const llvm::DataLayout &dl);
 /// nothing.
 void setMLIRDataLayoutFromAttributes(mlir::ModuleOp mlirModule,
                                      bool allowDefaultLayout);
+
+/// Create mlir::DataLayout from the data layout information on the
+/// mlir::Module. Creates the data layout information attributes with
+/// setMLIRDataLayoutFromAttributes if the DLTI attribute is not yet set. If no
+/// information is present at all and \p allowDefaultLayout is false, returns
+/// std::nullopt.
+std::optional<mlir::DataLayout>
+getOrSetDataLayout(mlir::ModuleOp mlirModule, bool allowDefaultLayout = false);
+
 } // namespace fir::support
 
 #endif // FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index bf175c8ebadee..c0f3ea3241a77 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -16,6 +16,7 @@
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Support/Utils.h"
@@ -34,6 +35,7 @@
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3820,10 +3822,20 @@ class FIRToLLVMLowering
     if (mlir::failed(runPipeline(mathConvertionPM, mod)))
       return signalPassFailure();
 
+    std::optional<mlir::DataLayout> dl =
+        fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+    if (!dl) {
+      mlir::emitError(mod.getLoc(),
+                      "module operation must carry a data layout attribute "
+                      "to generate llvm IR from FIR");
+      signalPassFailure();
+      return;
+    }
+
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule(),
                                          options.applyTBAA || applyTBAA,
-                                         options.forceUnifiedTBAATree};
+                                         options.forceUnifiedTBAATree, *dl};
     mlir::RewritePatternSet pattern(context);
     pattern.insert<
         AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 112f56e268c3b..ea10486a6b34c 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -18,6 +18,7 @@
 #include "flang/Optimizer/Support/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeRange.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "flang-codegen-target"
 
@@ -58,6 +59,62 @@ static void typeTodo(const llvm::fltSemantics *sem, mlir::Location loc,
   }
 }
 
+/// Return the size of alignment of FIR types.
+/// TODO: consider moving this to a DataLayoutTypeInterface implementation
+/// for FIR types. It should first be ensured that it is OK to open the gate of
+/// target dependent type size inquiries in lowering. It would also not be
+/// straightforward given the need for a kind map that would need to be
+/// converted in terms of mlir::DataLayoutEntryKey.
+static std::pair<std::uint64_t, unsigned short>
+getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
+                    const mlir::DataLayout &dl,
+                    const fir::KindMapping &kindMap) {
+  if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
+    llvm::TypeSize size = dl.getTypeSize(ty);
+    unsigned short alignment = dl.getTypeABIAlignment(ty);
+    return {size, alignment};
+  }
+  if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
+    auto [floatSize, floatAlign] =
+        getSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
+    return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
+  }
+  if (auto real = mlir::dyn_cast<fir::RealType>(ty))
+    return getSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap);
+
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
+
+    std::uint64_t size =
+        llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
+    return {size, eleAlign};
+  }
+  if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
+    std::uint64_t size = 0;
+    unsigned short align = 8;
+    for (auto component : recTy.getTypeList()) {
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, component.second, dl, kindMap);
+      size =
+          llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
+      align = std::max(align, compAlign);
+    }
+    return {size, align};
+  }
+  if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+    return getSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  if (auto logical = mlir::dyn_cast<fir::CharacterType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+    return getSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  TODO(loc, "computing size of a component");
+}
+
 namespace {
 template <typename S>
 struct GenericTarget : public CodeGenSpecifics {
@@ -95,6 +152,12 @@ struct GenericTarget : public CodeGenSpecifics {
     return marshal;
   }
 
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType,
+                     const Marshalling &) const override {
+    TODO(loc, "passing VALUE BIND(C) derived type for this target");
+  }
+
   CodeGenSpecifics::Marshalling
   integerArgumentType(mlir::Location loc,
                       mlir::IntegerType argTy) const override {
@@ -318,6 +381,251 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
     }
     return marshal;
   }
+
+  /// X86-64 argument classes from System V ABI version 1.0 section 3.2.3.
+  enum ArgClass {
+    Integer = 0,
+    SSE,
+    SSEUp,
+    X87,
+    X87Up,
+    ComplexX87,
+    NoClass,
+    Memory
+  };
+
+  /// Classify an argument type or a field of an aggregate type argument.
+  /// See ystem V ABI version 1.0 section 3.2.3.
+  /// The Lo and Hi class are set to the class of the lower eight eightbytes
+  /// and upper eight eightbytes on return.
+  /// If this is called for an aggregate field, the caller is responsible to
+  /// do the post-merge.
+  void classify(mlir::Location loc, mlir::Type type, std::uint64_t byteOffset,
+                ArgClass &Lo, ArgClass &Hi) const {
+    Hi = Lo = ArgClass::NoClass;
+    ArgClass &current = byteOffset < 8 ? Lo : Hi;
+    // System V AMD64 ABI 3.2.3. version 1.0
+    llvm::TypeSwitch<mlir::Type>(type)
+        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          if (intTy.getWidth() == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<mlir::FloatType, fir::RealType>([&](mlir::Type floatTy) {
+          const auto *sem = &floatToSemantics(kindMap, floatTy);
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            Lo = ArgClass::X87;
+            Hi = ArgClass::X87Up;
+          } else if (sem == &llvm::APFloat::IEEEquad()) {
+            Lo = ArgClass::SSE;
+            Hi = ArgClass::SSEUp;
+          } else {
+            current = ArgClass::SSE;
+          }
+        })
+        .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
+          const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            current = ArgClass::ComplexX87;
+          } else {
+            fir::SequenceType::Shape shape{2};
+            classifyArray(loc,
+                          fir::SequenceType::get(shape, cmplx.getElementType()),
+                          byteOffset, Lo, Hi);
+          }
+        })
+        .template Case<fir::LogicalType>([&](fir::LogicalType logical) {
+          if (kindMap.getLogicalBitsize(logical.getFKind()) == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<fir::CharacterType>(
+            [&](fir::CharacterType character) { current = ArgClass::Integer; })
+        .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+          // Array component.
+          classifyArray(loc, seqTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+          // Component that is a derived type.
+          classifyStruct(loc, recTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+          // Previously marshalled SSE eight byte for a previous struct
+          // argument.
+          auto *sem = fir::isa_real(vecTy.getEleTy())
+                          ? &floatToSemantics(kindMap, vecTy.getEleTy())
+                          : nullptr;
+          // Note expecting to hit this todo in standard code (it would
+          // require some vector type extension).
+          if (!(sem == &llvm::APFloat::IEEEsingle() && vecTy.getLen() <= 2) &&
+              !(sem == &llvm::APFloat::IEEEhalf() && vecTy.getLen() <= 4))
+            TODO(loc, "passing vector argument to C by value");
+          current = SSE;
+        })
+        .Default([&](mlir::Type ty) {
+          if (fir::conformsWithPassByRef(ty))
+            current = ArgClass::Integer; // Pointers.
+          else
+            TODO(loc, "unsupported component type for BIND(C), VALUE derived "
+                      "type argument");
+        });
+  }
+
+  // Classify fields of a derived type starting at \p offset. Returns the new
+  // offset. Post-merge is left to the caller.
+  std::uint64_t classifyStruct(mlir::Location loc, fir::RecordType recTy,
+                               std::uint64_t byteOffset, ArgClass &Lo,
+                               ArgClass &Hi) const {
+    for (auto component : recTy.getTypeList()) {
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return byteOffset;
+      }
+      mlir::Type compType = component.second;
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
+      byteOffset = llvm::alignTo(byteOffset, compAlign);
+      ArgClass LoComp, HiComp;
+      classify(loc, compType, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + llvm::alignTo(compSize, compAlign);
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return byteOffset;
+    }
+    return byteOffset;
+  }
+
+  // Classify fields of a constant size array type starting at \p offset.
+  // Returns the new offset. Post-merge is left to the caller.
+  void classifyArray(mlir::Location loc, fir::SequenceType seqTy,
+                     std::uint64_t byteOffset, ArgClass &Lo,
+                     ArgClass &Hi) const {
+    mlir::Type eleTy = seqTy.getEleTy();
+    const std::uint64_t arraySize = seqTy.getConstantArraySize();
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
+    std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
+    for (std::uint64_t i = 0; i < arraySize; ++i) {
+      byteOffset = llvm::alignTo(byteOffset, eleAlign);
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return;
+      }
+      ArgClass LoComp, HiComp;
+      classify(loc, eleTy, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + eleStorageSize;
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return;
+    }
+  }
+
+  // Goes through the previously marshalled arguments and count the
+  // register occupancy to check if there are enough registers left.
+  bool hasEnoughRegisters(mlir::Location loc, int neededIntRegisters,
+                          int neededSSERegisters,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 6;
+    int availSSERegisters = 8;
+    for (auto typeAndAttr : previousArguments) {
+      const auto &attr = std::get<Attributes>(typeAndAttr);
+      if (attr.isByVal() || attr.isSRet())
+        continue; // Previous argument passed on the stack.
+      ArgClass Lo, Hi;
+      Lo = Hi = ArgClass::NoClass;
+      classify(loc, std::get<mlir::Type>(typeAndAttr), 0, Lo, Hi);
+      // post merge is not needed here since previous aggregate arguments
+      // were marshalled into simpler arguments.
+      if (Lo == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Lo == SSE)
+        --availSSERegisters;
+      if (Hi == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Hi == ArgClass::SSE)
+        --availSSERegisters;
+    }
+    return availSSERegisters...
[truncated]
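The diff is cut off above, and the following is not a reconstruction of the missing lines. It is only a standalone sketch of the bookkeeping idea visible in `hasEnoughRegisters`: start from 6 integer and 8 SSE argument registers (the values used in the diff), subtract what earlier arguments consume, and compare with what the new struct needs. `RegClass`, `PrevArg` and `fitsInRegisters` are hypothetical names, and the final comparison is an assumption.

```cpp
#include <cassert>
#include <initializer_list>
#include <vector>

enum class RegClass { None, Integer, SSE };

// Hypothetical model of one previously marshalled argument.
struct PrevArg {
  RegClass lo;  // class of the first eightbyte
  RegClass hi;  // class of the second eightbyte, if any
  bool onStack; // true if the argument does not use argument registers
};

// Returns true if the needed registers are still free after the previous
// arguments have been accounted for (assumed comparison, see lead-in).
bool fitsInRegisters(int neededInt, int neededSSE,
                     const std::vector<PrevArg> &previous) {
  assert(neededInt >= 0 && neededSSE >= 0);
  int availInt = 6; // integer argument registers, as in the diff
  int availSSE = 8; // SSE argument registers, as in the diff
  for (const PrevArg &arg : previous) {
    if (arg.onStack)
      continue; // stack arguments do not consume registers in this model
    for (RegClass c : {arg.lo, arg.hi}) {
      if (c == RegClass::Integer)
        --availInt;
      else if (c == RegClass::SSE)
        --availSSE;
    }
  }
  return availInt >= neededInt && availSSE >= neededSSE;
}
```

The point of the sketch is that the answer depends on the whole argument list, which is why the marshalling entry point receives the previously rewritten arguments.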

Contributor

@tblah tblah left a comment

The approach looks good. I don't know enough about the X86 ABI to comment on any specifics.

Contributor

@psteinfeld psteinfeld left a comment

I didn't understand most of it, but what I understood looks good.

I found one nit and one compilation error that needs to be fixed.

jeanPerier added a commit to jeanPerier/llvm-project that referenced this pull request Dec 8, 2023
VALUE derived types are passed by reference outside of BIND(C) interfaces.
The ABI is much simpler and it is possible for these arguments to
have the OPTIONAL attribute.

In the BIND(C) context, these arguments must follow the C ABI for
structs, which may lead the data to be passed in registers. OPTIONAL
is also forbidden for those arguments, so it is safe to directly use
the fir.type<T> type for the func.func argument.

Codegen is in charge of later applying the C passing ABI according to
the target (llvm#74829).
@kiranchandramohan
Contributor

kiranchandramohan commented Dec 8, 2023

Thanks @jeanPerier for the patch. I have not had a chance to look at the patch in detail. I have two general questions:
-> Are you looking for help implementing the target specific parts for other targets?
-> The summary talks specifically about BIND(C) here, but don't the rules apply to derived types passed by value (with the VALUE attribute) in general? Say a derived type looks like a complex type internally, wouldn't that need some target specific handling like Complex?

@jeanPerier
Contributor Author

I found one nit and one compilation error that needs to be fixed.

Thanks @psteinfeld. I fixed the assert issue. What is the nit you mentioned?

-> Are you looking for help implementing the target specific parts for other targets?

Currently no. I briefly looked at the AArch64 rules, and they seem similar and much simpler. Derived types bigger than 16 bytes are passed in memory (without the byval). Smaller structs get rewritten into a single LLVM array argument [size x eleTy] align.
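The AArch64 rule as summarized in this comment can be sketched in a few lines. This is a rough illustration only: `aarch64StructArgType` is a hypothetical helper, the element type is assumed to be `i64` (the comment leaves `eleTy` open), and homogeneous floating-point aggregates, which the real AArch64 rules treat specially, are ignored.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Returns a textual LLVM-like argument type for a struct of the given size,
// following the simplified rule: larger than 16 bytes goes to memory through
// a plain pointer, otherwise the struct becomes a small array of i64.
std::string aarch64StructArgType(std::uint64_t sizeInBytes) {
  assert(sizeInBytes > 0);
  if (sizeInBytes > 16)
    return "ptr"; // passed in memory, without byval
  std::uint64_t count = (sizeInBytes + 7) / 8; // round up to 8-byte units
  return "[" + std::to_string(count) + " x i64]";
}
```

Compared with X86-64, there is no per-eightbyte class to track in this simplified view, which fits the "much simpler" remark.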

My long-term hope is for some C/C++ compiler to increase support for the C ABI in MLIR, or for someone to go through the big project of extracting/bridging with the clang ABIInfo class (https://github.com/llvm/llvm-project/blame/46a56931251eba767929f6a2110da5b1bcbc5eb9/clang/lib/CodeGen/ABIInfo.h#L68), which is not even public right now. But right now this would sidetrack us a lot from Fortran language support for little gain (BIND(C) VALUE still seems to be a corner case that is not used much).

-> The summary talks specifically about BIND(C) here, but don't the rules apply to derived types passed by value (with the VALUE attribute) in general? Say a derived type looks like a complex type internally, wouldn't that need some target specific handling like Complex?

Without the BIND(C) attribute, derived types passed by value are passed in memory in lowering. There are three reasons for this:

  • to support the OPTIONAL VALUE case that is not possible with BIND(C) (Fortran 2015 C865).
  • the ABI is much simpler/target independent.
  • that is what most compilers are doing.

So the need to pass a derived type as a C struct is a corner case that only occurs with BIND(C) + VALUE.
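The difference between the two conventions can be illustrated from the C side. The sketch below is an assumption-laden illustration, not a description of flang's lowering: `Point`, `bump_by_value`, `bump_by_reference` and `call_with_copy` are hypothetical names. It shows that value semantics can be kept with either convention, so only the BIND(C) case is forced to match the C struct-by-value ABI.

```cpp
#include <cassert>

extern "C" {
// Hypothetical C counterpart of a BIND(C) derived type with two components.
struct Point {
  int x;
  float y;
};

// C struct passed by value: the form a BIND(C) procedure with a VALUE dummy
// has to match. The callee works on its own copy.
int bump_by_value(struct Point p) {
  p.x += 1;
  return p.x;
}

// By-reference form: the callee receives an address.
int bump_by_reference(struct Point *p) {
  p->x += 1;
  return p->x;
}
}

// One way to keep value semantics with the by-reference form: the caller
// passes the address of a temporary copy.
int call_with_copy(const Point &original) {
  Point tmp = original;
  return bump_by_reference(&tmp);
}
```

In both `bump_by_value(p)` and `call_with_copy(p)` the caller's `p` is left untouched; what differs is how the data reaches the callee, and that is the part a C compiler on the other side of a BIND(C) interface depends on.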

@psteinfeld
Contributor

I found one nit and one compilation error that needs to be fixed.

Thanks @psteinfeld. I fixed the assert issue. What is the nit you mentionned?

In flang/lib/Optimizer/CodeGen/Target.cpp, line 399, "ystem" should read "System".

All builds and tests successfully now for me.

Thanks for the quick fix!

Contributor

@psteinfeld psteinfeld left a comment

Looks great!

Contributor

@vzakhari vzakhari left a comment

Thank you for working on this, Jean!

Contributor Author

@jeanPerier jeanPerier left a comment

Thanks for the reviews

Contributor

@vzakhari vzakhari left a comment

Thank you for the update, Jean! I have only one minor comment about 0 alignment.

@jeanPerier jeanPerier merged commit 27d9a47 into llvm:main Dec 12, 2023
@jeanPerier jeanPerier deleted the jpr-struct-abi branch December 12, 2023 10:52
jeanPerier added a commit that referenced this pull request Dec 12, 2023