Skip to content

[flang][AIX] BIND(C) derived type alignment for AIX #121505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/CodeGen/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<llvm::LogicalResult>
convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results);
llvm::SmallVectorImpl<mlir::Type> &results, bool isPacked);

// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
void finalize(llvm::ArrayRef<TypePair> lenPList,
llvm::ArrayRef<TypePair> typeList);

// fir.type is unpacked by default. If the flag is set, the packed fir.type
// is generated and the alignment is enforced by explicit padding by i8
// array fields.
bool isPacked() const;
void pack(bool);

detail::RecordTypeStorage const *uniqueKey() const;
}];
}
Expand Down
44 changes: 44 additions & 0 deletions flang/lib/Lower/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"

#define DEBUG_TYPE "flang-lower-type"

Expand Down Expand Up @@ -385,9 +387,20 @@ struct TypeBuilderImpl {
// with dozens of components/parents (modern Fortran).
derivedTypeInConstruction.try_emplace(&derivedScope, rec);

auto targetTriple{llvm::Triple(
llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
// Always generate packed FIR struct type for bind(c) derived type for AIX
if (targetTriple.getOS() == llvm::Triple::OSType::AIX &&
tySpec.typeSymbol().attrs().test(Fortran::semantics::Attr::BIND_C) &&
!IsIsoCType(&tySpec)) {
rec.pack(true);
}

// Gather the record type fields.
// (1) The data components.
if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
size_t prev_offset{0};
unsigned padCounter{0};
// In HLFIR the parent component is the first fir.type component.
for (const auto &componentName :
typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
Expand All @@ -397,7 +410,38 @@ struct TypeBuilderImpl {
"failed to find derived type component symbol");
const Fortran::semantics::Symbol &component = scopeIter->second.get();
mlir::Type ty = genSymbolType(component);
if (rec.isPacked()) {
auto compSize{component.size()};
auto compOffset{component.offset()};

if (prev_offset < compOffset) {
size_t pad{compOffset - prev_offset};
mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
prev_offset += pad;
cs.emplace_back("__padding" + std::to_string(padCounter++), padTy);
}
prev_offset += compSize;
}
cs.emplace_back(converter.getRecordTypeFieldName(component), ty);
if (rec.isPacked()) {
// For the last component, determine if any padding is needed.
if (componentName ==
typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
.componentNames()
.back()) {
auto compEnd{component.offset() + component.size()};
if (compEnd < derivedScope.size()) {
size_t pad{derivedScope.size() - compEnd};
mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
cs.emplace_back("__padding" + std::to_string(padCounter++),
padTy);
}
}
}
}
} else {
for (const auto &component :
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
cs.emplace_back(t.first, t.second);
}
rec.finalize(ps, cs);
rec.pack(ty.isPacked());
return rec;
});
addConversion([&](TypeDescType ty) {
Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
addConversion(
[&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
return convertRecordType(derived, results);
return convertRecordType(derived, results, derived.isPacked());
});
addConversion(
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
Expand Down Expand Up @@ -133,8 +133,10 @@ mlir::Type LLVMTypeConverter::indexType() const {
}

// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
std::optional<llvm::LogicalResult>
LLVMTypeConverter::convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
bool isPacked) {
auto name = fir::NameUniquer::dropTypeConversionMarkers(derived.getName());
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);

Expand All @@ -156,7 +158,7 @@ std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
else
members.push_back(mlir::cast<mlir::Type>(convertType(mem.second)));
}
if (mlir::failed(st.setBody(members, /*isPacked=*/false)))
if (mlir::failed(st.setBody(members, isPacked)))
return mlir::failure();
results.push_back(st);
return mlir::success();
Expand Down
31 changes: 28 additions & 3 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,20 @@ struct RecordTypeStorage : public mlir::TypeStorage {
setTypeList(typeList);
}

bool isPacked() const { return packed; }
void pack(bool p) { packed = p; }

protected:
std::string name;
bool finalized;
bool packed;
std::vector<RecordType::TypePair> lens;
std::vector<RecordType::TypePair> types;

private:
RecordTypeStorage() = delete;
explicit RecordTypeStorage(llvm::StringRef name)
: name{name}, finalized{false} {}
: name{name}, finalized{false}, packed{false} {}
};

} // namespace detail
Expand Down Expand Up @@ -872,9 +876,14 @@ llvm::LogicalResult fir::PointerType::verify(
//===----------------------------------------------------------------------===//

// Fortran derived type
// unpacked:
// `type` `<` name
// (`(` id `:` type (`,` id `:` type)* `)`)?
// (`{` id `:` type (`,` id `:` type)* `}`)? '>'
// packed:
// `type` `<` name
// (`(` id `:` type (`,` id `:` type)* `)`)?
// (`<{` id `:` type (`,` id `:` type)* `}>`)? '>'
mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
llvm::StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
Expand All @@ -900,6 +909,10 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
}

RecordType::TypeList typeList;
if (!parser.parseOptionalLess()) {
result.pack(true);
}

if (!parser.parseOptionalLBrace()) {
while (true) {
llvm::StringRef field;
Expand All @@ -913,8 +926,10 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
if (parser.parseOptionalComma())
break;
}
if (parser.parseRBrace())
return {};
if (parser.parseOptionalGreater()) {
if (parser.parseRBrace())
return {};
}
}

if (parser.parseGreater())
Expand All @@ -941,13 +956,19 @@ void fir::RecordType::print(mlir::AsmPrinter &printer) const {
printer << ')';
}
if (getTypeList().size()) {
if (isPacked()) {
printer << '<';
}
char ch = '{';
for (auto p : getTypeList()) {
printer << ch << p.first << ':';
p.second.print(printer.getStream());
ch = ',';
}
printer << '}';
if (isPacked()) {
printer << '>';
}
}
recordTypeVisited.erase(uniqueKey());
}
Expand All @@ -973,6 +994,10 @@ RecordType::TypeList fir::RecordType::getLenParamList() const {

bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }

void fir::RecordType::pack(bool p) { getImpl()->pack(p); }

bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }

detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
return getImpl();
}
Expand Down
88 changes: 83 additions & 5 deletions flang/lib/Semantics/compute-offsets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <vector>

Expand Down Expand Up @@ -51,9 +53,12 @@ class ComputeOffsetsHelper {
SymbolAndOffset Resolve(const SymbolAndOffset &);
std::size_t ComputeOffset(const EquivalenceObject &);
// Returns amount of padding that was needed for alignment
std::size_t DoSymbol(Symbol &);
std::size_t DoSymbol(
Symbol &, std::optional<const size_t> newAlign = std::nullopt);
SizeAndAlignment GetSizeAndAlignment(const Symbol &, bool entire);
std::size_t Align(std::size_t, std::size_t);
std::optional<size_t> CompAlignment(const Symbol &);
std::optional<size_t> HasSpecialAlign(const Symbol &, Scope &);

SemanticsContext &context_;
std::size_t offset_{0};
Expand All @@ -65,6 +70,69 @@ class ComputeOffsetsHelper {
equivalenceBlock_;
};

// This function is only called if the target platform is AIX.
static bool isReal8OrLarger(const Fortran::semantics::DeclTypeSpec *type) {
return ((type->IsNumeric(common::TypeCategory::Real) ||
type->IsNumeric(common::TypeCategory::Complex)) &&
evaluate::ToInt64(type->numericTypeSpec().kind()) > 4);
}

// This function is only called if the target platform is AIX.
// It determines the alignment of a component. If the component is a derived
// type, the alignment is computed accordingly.
std::optional<size_t> ComputeOffsetsHelper::CompAlignment(const Symbol &sym) {
size_t max_align{0};
constexpr size_t fourByteAlign{4};
bool contain_double{false};
auto derivedTypeSpec{sym.GetType()->AsDerived()};
DirectComponentIterator directs{*derivedTypeSpec};
for (auto it{directs.begin()}; it != directs.end(); ++it) {
auto type{it->GetType()};
auto s{GetSizeAndAlignment(*it, true)};
if (isReal8OrLarger(type)) {
max_align = std::max(max_align, fourByteAlign);
contain_double = true;
} else if (type->AsDerived()) {
if (const auto newAlgin{CompAlignment(*it)}) {
max_align = std::max(max_align, s.alignment);
} else {
return std::nullopt;
}
} else {
max_align = std::max(max_align, s.alignment);
}
}

if (contain_double) {
return max_align;
} else {
return std::nullopt;
}
}

// This function is only called if the target platform is AIX.
// Special alignment is needed only if it is a bind(c) derived type
// and contain real type components that have larger than 4 bytes.
std::optional<size_t> ComputeOffsetsHelper::HasSpecialAlign(
const Symbol &sym, Scope &scope) {
// On AIX, if the component that is not the first component and is
// a float of 8 bytes or larger, it has the 4-byte alignment.
// Only set the special alignment for bind(c) derived type on that platform.
if (const auto type{sym.GetType()}) {
auto &symOwner{sym.owner()};
if (symOwner.symbol() && symOwner.IsDerivedType() &&
symOwner.symbol()->attrs().HasAny({semantics::Attr::BIND_C}) &&
&sym != &(*scope.GetSymbols().front())) {
if (isReal8OrLarger(type)) {
return 4UL;
} else if (type->AsDerived()) {
return CompAlignment(sym);
}
}
}
return std::nullopt;
}

void ComputeOffsetsHelper::Compute(Scope &scope) {
for (Scope &child : scope.children()) {
ComputeOffsets(context_, child);
Expand Down Expand Up @@ -113,7 +181,15 @@ void ComputeOffsetsHelper::Compute(Scope &scope) {
if (!FindCommonBlockContaining(*symbol) &&
dependents_.find(symbol) == dependents_.end() &&
equivalenceBlock_.find(symbol) == equivalenceBlock_.end()) {
DoSymbol(*symbol);

std::optional<size_t> newAlign{std::nullopt};
// Handle special alignment requirement for AIX
auto triple{llvm::Triple(
llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
if (triple.getOS() == llvm::Triple::OSType::AIX) {
newAlign = HasSpecialAlign(*symbol, scope);
}
DoSymbol(*symbol, newAlign);
if (auto *generic{symbol->detailsIf<GenericDetails>()}) {
if (Symbol * specific{generic->specific()};
specific && !FindCommonBlockContaining(*specific)) {
Expand Down Expand Up @@ -313,7 +389,8 @@ std::size_t ComputeOffsetsHelper::ComputeOffset(
return result;
}

std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
std::size_t ComputeOffsetsHelper::DoSymbol(
Symbol &symbol, std::optional<const size_t> newAlign) {
if (!symbol.has<ObjectEntityDetails>() && !symbol.has<ProcEntityDetails>()) {
return 0;
}
Expand All @@ -322,12 +399,13 @@ std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
return 0;
}
std::size_t previousOffset{offset_};
offset_ = Align(offset_, s.alignment);
size_t alignVal{newAlign.value_or(s.alignment)};
offset_ = Align(offset_, alignVal);
std::size_t padding{offset_ - previousOffset};
symbol.set_size(s.size);
symbol.set_offset(offset_);
offset_ += s.size;
alignment_ = std::max(alignment_, s.alignment);
alignment_ = std::max(alignment_, alignVal);
return padding;
}

Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/CUDA/cuda-devptr.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ end

! CHECK-LABEL: func.func @_QPsub2()
! CHECK: %[[X:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFsub2Ex"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
! CHECK: %[[CPTR:.*]] = fir.field_index cptr, !fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>
! CHECK: %[[CPTR_COORD:.*]] = fir.coordinate_of %{{.*}}#1, %[[CPTR]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
! CHECK: %[[CPTR:.*]] = fir.field_index cptr, !fir.type<_QM__fortran_builtinsT__builtin_c_devptr{{[<]?}}{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}{{[>]?}}>
! CHECK: %[[CPTR_COORD:.*]] = fir.coordinate_of %{{.*}}#1, %[[CPTR]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{{[<]?}}{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}{{[>]?}}>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
! CHECK: %[[ADDRESS:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
! CHECK: %[[ADDRESS_COORD:.*]] = fir.coordinate_of %[[CPTR_COORD]], %[[ADDRESS]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
! CHECK: %[[ADDRESS_LOADED:.*]] = fir.load %[[ADDRESS_COORD]] : !fir.ref<i64>
Expand Down
Loading
Loading