Skip to content

Commit 79e788d

Browse files
authored
[flang][AIX] BIND(C) derived type alignment for AIX (#121505)
This patch is to handle the alignment requirement for the `bind(c)` derived type component that is real type and larger than 4 bytes. The alignment of such component is 4-byte.
1 parent e2c49a4 commit 79e788d

File tree

13 files changed

+340
-35
lines changed

13 files changed

+340
-35
lines changed

flang/include/flang/Optimizer/CodeGen/TypeConverter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
6262
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
6363
std::optional<llvm::LogicalResult>
6464
convertRecordType(fir::RecordType derived,
65-
llvm::SmallVectorImpl<mlir::Type> &results);
65+
llvm::SmallVectorImpl<mlir::Type> &results, bool isPacked);
6666

6767
// Is an extended descriptor needed given the element type of a fir.box type ?
6868
// Extended descriptors are required for derived types.

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,12 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
346346
void finalize(llvm::ArrayRef<TypePair> lenPList,
347347
llvm::ArrayRef<TypePair> typeList);
348348

349+
// fir.type is unpacked by default. If the flag is set, the packed fir.type
350+
// is generated and the alignment is enforced by explicit padding by i8
351+
// array fields.
352+
bool isPacked() const;
353+
void pack(bool);
354+
349355
detail::RecordTypeStorage const *uniqueKey() const;
350356
}];
351357
}

flang/lib/Lower/ConvertType.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "mlir/IR/Builders.h"
2121
#include "mlir/IR/BuiltinTypes.h"
2222
#include "llvm/Support/Debug.h"
23+
#include "llvm/TargetParser/Host.h"
24+
#include "llvm/TargetParser/Triple.h"
2325

2426
#define DEBUG_TYPE "flang-lower-type"
2527

@@ -385,9 +387,20 @@ struct TypeBuilderImpl {
385387
// with dozens of components/parents (modern Fortran).
386388
derivedTypeInConstruction.try_emplace(&derivedScope, rec);
387389

390+
auto targetTriple{llvm::Triple(
391+
llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
392+
// Always generate packed FIR struct type for bind(c) derived type for AIX
393+
if (targetTriple.getOS() == llvm::Triple::OSType::AIX &&
394+
tySpec.typeSymbol().attrs().test(Fortran::semantics::Attr::BIND_C) &&
395+
!IsIsoCType(&tySpec)) {
396+
rec.pack(true);
397+
}
398+
388399
// Gather the record type fields.
389400
// (1) The data components.
390401
if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
402+
size_t prev_offset{0};
403+
unsigned padCounter{0};
391404
// In HLFIR the parent component is the first fir.type component.
392405
for (const auto &componentName :
393406
typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
@@ -397,7 +410,38 @@ struct TypeBuilderImpl {
397410
"failed to find derived type component symbol");
398411
const Fortran::semantics::Symbol &component = scopeIter->second.get();
399412
mlir::Type ty = genSymbolType(component);
413+
if (rec.isPacked()) {
414+
auto compSize{component.size()};
415+
auto compOffset{component.offset()};
416+
417+
if (prev_offset < compOffset) {
418+
size_t pad{compOffset - prev_offset};
419+
mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
420+
fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
421+
mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
422+
prev_offset += pad;
423+
cs.emplace_back("__padding" + std::to_string(padCounter++), padTy);
424+
}
425+
prev_offset += compSize;
426+
}
400427
cs.emplace_back(converter.getRecordTypeFieldName(component), ty);
428+
if (rec.isPacked()) {
429+
// For the last component, determine if any padding is needed.
430+
if (componentName ==
431+
typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
432+
.componentNames()
433+
.back()) {
434+
auto compEnd{component.offset() + component.size()};
435+
if (compEnd < derivedScope.size()) {
436+
size_t pad{derivedScope.size() - compEnd};
437+
mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
438+
fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
439+
mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
440+
cs.emplace_back("__padding" + std::to_string(padCounter++),
441+
padTy);
442+
}
443+
}
444+
}
401445
}
402446
} else {
403447
for (const auto &component :

flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
167167
cs.emplace_back(t.first, t.second);
168168
}
169169
rec.finalize(ps, cs);
170+
rec.pack(ty.isPacked());
170171
return rec;
171172
});
172173
addConversion([&](TypeDescType ty) {

flang/lib/Optimizer/CodeGen/TypeConverter.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
8282
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
8383
addConversion(
8484
[&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
85-
return convertRecordType(derived, results);
85+
return convertRecordType(derived, results, derived.isPacked());
8686
});
8787
addConversion(
8888
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
@@ -133,8 +133,10 @@ mlir::Type LLVMTypeConverter::indexType() const {
133133
}
134134

135135
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
136-
std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
137-
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
136+
std::optional<llvm::LogicalResult>
137+
LLVMTypeConverter::convertRecordType(fir::RecordType derived,
138+
llvm::SmallVectorImpl<mlir::Type> &results,
139+
bool isPacked) {
138140
auto name = fir::NameUniquer::dropTypeConversionMarkers(derived.getName());
139141
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
140142

@@ -156,7 +158,7 @@ std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
156158
else
157159
members.push_back(mlir::cast<mlir::Type>(convertType(mem.second)));
158160
}
159-
if (mlir::failed(st.setBody(members, /*isPacked=*/false)))
161+
if (mlir::failed(st.setBody(members, isPacked)))
160162
return mlir::failure();
161163
results.push_back(st);
162164
return mlir::success();

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,20 @@ struct RecordTypeStorage : public mlir::TypeStorage {
165165
setTypeList(typeList);
166166
}
167167

168+
bool isPacked() const { return packed; }
169+
void pack(bool p) { packed = p; }
170+
168171
protected:
169172
std::string name;
170173
bool finalized;
174+
bool packed;
171175
std::vector<RecordType::TypePair> lens;
172176
std::vector<RecordType::TypePair> types;
173177

174178
private:
175179
RecordTypeStorage() = delete;
176180
explicit RecordTypeStorage(llvm::StringRef name)
177-
: name{name}, finalized{false} {}
181+
: name{name}, finalized{false}, packed{false} {}
178182
};
179183

180184
} // namespace detail
@@ -872,9 +876,14 @@ llvm::LogicalResult fir::PointerType::verify(
872876
//===----------------------------------------------------------------------===//
873877

874878
// Fortran derived type
879+
// unpacked:
875880
// `type` `<` name
876881
// (`(` id `:` type (`,` id `:` type)* `)`)?
877882
// (`{` id `:` type (`,` id `:` type)* `}`)? '>'
883+
// packed:
884+
// `type` `<` name
885+
// (`(` id `:` type (`,` id `:` type)* `)`)?
886+
// (`<{` id `:` type (`,` id `:` type)* `}>`)? '>'
878887
mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
879888
llvm::StringRef name;
880889
if (parser.parseLess() || parser.parseKeyword(&name))
@@ -900,6 +909,10 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
900909
}
901910

902911
RecordType::TypeList typeList;
912+
if (!parser.parseOptionalLess()) {
913+
result.pack(true);
914+
}
915+
903916
if (!parser.parseOptionalLBrace()) {
904917
while (true) {
905918
llvm::StringRef field;
@@ -913,8 +926,10 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
913926
if (parser.parseOptionalComma())
914927
break;
915928
}
916-
if (parser.parseRBrace())
917-
return {};
929+
if (parser.parseOptionalGreater()) {
930+
if (parser.parseRBrace())
931+
return {};
932+
}
918933
}
919934

920935
if (parser.parseGreater())
@@ -941,13 +956,19 @@ void fir::RecordType::print(mlir::AsmPrinter &printer) const {
941956
printer << ')';
942957
}
943958
if (getTypeList().size()) {
959+
if (isPacked()) {
960+
printer << '<';
961+
}
944962
char ch = '{';
945963
for (auto p : getTypeList()) {
946964
printer << ch << p.first << ':';
947965
p.second.print(printer.getStream());
948966
ch = ',';
949967
}
950968
printer << '}';
969+
if (isPacked()) {
970+
printer << '>';
971+
}
951972
}
952973
recordTypeVisited.erase(uniqueKey());
953974
}
@@ -973,6 +994,10 @@ RecordType::TypeList fir::RecordType::getLenParamList() const {
973994

974995
bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
975996

997+
void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
998+
999+
bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
1000+
9761001
detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
9771002
return getImpl();
9781003
}

flang/lib/Semantics/compute-offsets.cpp

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "flang/Semantics/symbol.h"
1818
#include "flang/Semantics/tools.h"
1919
#include "flang/Semantics/type.h"
20+
#include "llvm/TargetParser/Host.h"
21+
#include "llvm/TargetParser/Triple.h"
2022
#include <algorithm>
2123
#include <vector>
2224

@@ -51,9 +53,12 @@ class ComputeOffsetsHelper {
5153
SymbolAndOffset Resolve(const SymbolAndOffset &);
5254
std::size_t ComputeOffset(const EquivalenceObject &);
5355
// Returns amount of padding that was needed for alignment
54-
std::size_t DoSymbol(Symbol &);
56+
std::size_t DoSymbol(
57+
Symbol &, std::optional<const size_t> newAlign = std::nullopt);
5558
SizeAndAlignment GetSizeAndAlignment(const Symbol &, bool entire);
5659
std::size_t Align(std::size_t, std::size_t);
60+
std::optional<size_t> CompAlignment(const Symbol &);
61+
std::optional<size_t> HasSpecialAlign(const Symbol &, Scope &);
5762

5863
SemanticsContext &context_;
5964
std::size_t offset_{0};
@@ -65,6 +70,69 @@ class ComputeOffsetsHelper {
6570
equivalenceBlock_;
6671
};
6772

73+
// This function is only called if the target platform is AIX.
74+
static bool isReal8OrLarger(const Fortran::semantics::DeclTypeSpec *type) {
75+
return ((type->IsNumeric(common::TypeCategory::Real) ||
76+
type->IsNumeric(common::TypeCategory::Complex)) &&
77+
evaluate::ToInt64(type->numericTypeSpec().kind()) > 4);
78+
}
79+
80+
// This function is only called if the target platform is AIX.
81+
// It determines the alignment of a component. If the component is a derived
82+
// type, the alignment is computed accordingly.
83+
std::optional<size_t> ComputeOffsetsHelper::CompAlignment(const Symbol &sym) {
84+
size_t max_align{0};
85+
constexpr size_t fourByteAlign{4};
86+
bool contain_double{false};
87+
auto derivedTypeSpec{sym.GetType()->AsDerived()};
88+
DirectComponentIterator directs{*derivedTypeSpec};
89+
for (auto it{directs.begin()}; it != directs.end(); ++it) {
90+
auto type{it->GetType()};
91+
auto s{GetSizeAndAlignment(*it, true)};
92+
if (isReal8OrLarger(type)) {
93+
max_align = std::max(max_align, fourByteAlign);
94+
contain_double = true;
95+
} else if (type->AsDerived()) {
96+
if (const auto newAlgin{CompAlignment(*it)}) {
97+
max_align = std::max(max_align, s.alignment);
98+
} else {
99+
return std::nullopt;
100+
}
101+
} else {
102+
max_align = std::max(max_align, s.alignment);
103+
}
104+
}
105+
106+
if (contain_double) {
107+
return max_align;
108+
} else {
109+
return std::nullopt;
110+
}
111+
}
112+
113+
// This function is only called if the target platform is AIX.
114+
// Special alignment is needed only if it is a bind(c) derived type
115+
// and contain real type components that have larger than 4 bytes.
116+
std::optional<size_t> ComputeOffsetsHelper::HasSpecialAlign(
117+
const Symbol &sym, Scope &scope) {
118+
// On AIX, if the component that is not the first component and is
119+
// a float of 8 bytes or larger, it has the 4-byte alignment.
120+
// Only set the special alignment for bind(c) derived type on that platform.
121+
if (const auto type{sym.GetType()}) {
122+
auto &symOwner{sym.owner()};
123+
if (symOwner.symbol() && symOwner.IsDerivedType() &&
124+
symOwner.symbol()->attrs().HasAny({semantics::Attr::BIND_C}) &&
125+
&sym != &(*scope.GetSymbols().front())) {
126+
if (isReal8OrLarger(type)) {
127+
return 4UL;
128+
} else if (type->AsDerived()) {
129+
return CompAlignment(sym);
130+
}
131+
}
132+
}
133+
return std::nullopt;
134+
}
135+
68136
void ComputeOffsetsHelper::Compute(Scope &scope) {
69137
for (Scope &child : scope.children()) {
70138
ComputeOffsets(context_, child);
@@ -113,7 +181,15 @@ void ComputeOffsetsHelper::Compute(Scope &scope) {
113181
if (!FindCommonBlockContaining(*symbol) &&
114182
dependents_.find(symbol) == dependents_.end() &&
115183
equivalenceBlock_.find(symbol) == equivalenceBlock_.end()) {
116-
DoSymbol(*symbol);
184+
185+
std::optional<size_t> newAlign{std::nullopt};
186+
// Handle special alignment requirement for AIX
187+
auto triple{llvm::Triple(
188+
llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
189+
if (triple.getOS() == llvm::Triple::OSType::AIX) {
190+
newAlign = HasSpecialAlign(*symbol, scope);
191+
}
192+
DoSymbol(*symbol, newAlign);
117193
if (auto *generic{symbol->detailsIf<GenericDetails>()}) {
118194
if (Symbol * specific{generic->specific()};
119195
specific && !FindCommonBlockContaining(*specific)) {
@@ -313,7 +389,8 @@ std::size_t ComputeOffsetsHelper::ComputeOffset(
313389
return result;
314390
}
315391

316-
std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
392+
std::size_t ComputeOffsetsHelper::DoSymbol(
393+
Symbol &symbol, std::optional<const size_t> newAlign) {
317394
if (!symbol.has<ObjectEntityDetails>() && !symbol.has<ProcEntityDetails>()) {
318395
return 0;
319396
}
@@ -322,12 +399,13 @@ std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
322399
return 0;
323400
}
324401
std::size_t previousOffset{offset_};
325-
offset_ = Align(offset_, s.alignment);
402+
size_t alignVal{newAlign.value_or(s.alignment)};
403+
offset_ = Align(offset_, alignVal);
326404
std::size_t padding{offset_ - previousOffset};
327405
symbol.set_size(s.size);
328406
symbol.set_offset(offset_);
329407
offset_ += s.size;
330-
alignment_ = std::max(alignment_, s.alignment);
408+
alignment_ = std::max(alignment_, alignVal);
331409
return padding;
332410
}
333411

flang/test/Lower/CUDA/cuda-devptr.cuf

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ end
3838

3939
! CHECK-LABEL: func.func @_QPsub2()
4040
! CHECK: %[[X:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFsub2Ex"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
41-
! CHECK: %[[CPTR:.*]] = fir.field_index cptr, !fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>
42-
! CHECK: %[[CPTR_COORD:.*]] = fir.coordinate_of %{{.*}}#1, %[[CPTR]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
41+
! CHECK: %[[CPTR:.*]] = fir.field_index cptr, !fir.type<_QM__fortran_builtinsT__builtin_c_devptr{{[<]?}}{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}{{[>]?}}>
42+
! CHECK: %[[CPTR_COORD:.*]] = fir.coordinate_of %{{.*}}#1, %[[CPTR]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{{[<]?}}{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}{{[>]?}}>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
4343
! CHECK: %[[ADDRESS:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
4444
! CHECK: %[[ADDRESS_COORD:.*]] = fir.coordinate_of %[[CPTR_COORD]], %[[ADDRESS]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
4545
! CHECK: %[[ADDRESS_LOADED:.*]] = fir.load %[[ADDRESS_COORD]] : !fir.ref<i64>

0 commit comments

Comments
 (0)