Skip to content

Commit 3e47e75

Browse files
authored
[flang] Use DataLayout for computing type size in LoopVersioning. (#79778)
The existing type size computation in LoopVersioning does not work for REAL*10, because the compute element size is 10 bytes, which violates the power-of-two assertion. We'd better use the DataLayout for computing the storage size of each element of an array of the given type.
1 parent 72d4fc1 commit 3e47e75

File tree

5 files changed

+168
-66
lines changed

5 files changed

+168
-66
lines changed

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1819
#include "llvm/ADT/SmallVector.h"
1920
#include "llvm/IR/Type.h"
2021

@@ -465,6 +466,17 @@ inline bool isBoxProcAddressType(mlir::Type t) {
465466
std::string getTypeAsString(mlir::Type ty, const KindMapping &kindMap,
466467
llvm::StringRef prefix = "");
467468

469+
/// Return the size and alignment of FIR types.
470+
/// TODO: consider moving this to a DataLayoutTypeInterface implementation
471+
/// for FIR types. It should first be ensured that it is OK to open the gate of
472+
/// target dependent type size inquiries in lowering. It would also not be
473+
/// straightforward given the need for a kind map that would need to be
474+
/// converted in terms of mlir::DataLayoutEntryKey.
475+
std::pair<std::uint64_t, unsigned short>
476+
getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
477+
const mlir::DataLayout &dl,
478+
const fir::KindMapping &kindMap);
479+
468480
} // namespace fir
469481

470482
#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -59,63 +59,6 @@ static void typeTodo(const llvm::fltSemantics *sem, mlir::Location loc,
5959
}
6060
}
6161

62-
/// Return the size and alignment of FIR types.
63-
/// TODO: consider moving this to a DataLayoutTypeInterface implementation
64-
/// for FIR types. It should first be ensured that it is OK to open the gate of
65-
/// target dependent type size inquiries in lowering. It would also not be
66-
/// straightforward given the need for a kind map that would need to be
67-
/// converted in terms of mlir::DataLayoutEntryKey.
68-
static std::pair<std::uint64_t, unsigned short>
69-
getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
70-
const mlir::DataLayout &dl,
71-
const fir::KindMapping &kindMap) {
72-
if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
73-
llvm::TypeSize size = dl.getTypeSize(ty);
74-
unsigned short alignment = dl.getTypeABIAlignment(ty);
75-
return {size, alignment};
76-
}
77-
if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
78-
auto [floatSize, floatAlign] =
79-
getSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
80-
return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
81-
}
82-
if (auto real = mlir::dyn_cast<fir::RealType>(ty))
83-
return getSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap);
84-
85-
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
86-
auto [eleSize, eleAlign] =
87-
getSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
88-
89-
std::uint64_t size =
90-
llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
91-
return {size, eleAlign};
92-
}
93-
if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
94-
std::uint64_t size = 0;
95-
unsigned short align = 1;
96-
for (auto component : recTy.getTypeList()) {
97-
auto [compSize, compAlign] =
98-
getSizeAndAlignment(loc, component.second, dl, kindMap);
99-
size =
100-
llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
101-
align = std::max(align, compAlign);
102-
}
103-
return {size, align};
104-
}
105-
if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
106-
mlir::Type intTy = mlir::IntegerType::get(
107-
logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
108-
return getSizeAndAlignment(loc, intTy, dl, kindMap);
109-
}
110-
if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
111-
mlir::Type intTy = mlir::IntegerType::get(
112-
character.getContext(),
113-
kindMap.getCharacterBitsize(character.getFKind()));
114-
return getSizeAndAlignment(loc, intTy, dl, kindMap);
115-
}
116-
TODO(loc, "computing size of a component");
117-
}
118-
11962
namespace {
12063
template <typename S>
12164
struct GenericTarget : public CodeGenSpecifics {
@@ -489,7 +432,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
489432
}
490433
mlir::Type compType = component.second;
491434
auto [compSize, compAlign] =
492-
getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
435+
fir::getTypeSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
493436
byteOffset = llvm::alignTo(byteOffset, compAlign);
494437
ArgClass LoComp, HiComp;
495438
classify(loc, compType, byteOffset, LoComp, HiComp);
@@ -510,7 +453,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
510453
mlir::Type eleTy = seqTy.getEleTy();
511454
const std::uint64_t arraySize = seqTy.getConstantArraySize();
512455
auto [eleSize, eleAlign] =
513-
getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
456+
fir::getTypeSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
514457
std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
515458
for (std::uint64_t i = 0; i < arraySize; ++i) {
516459
byteOffset = llvm::alignTo(byteOffset, eleAlign);
@@ -697,7 +640,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
697640
CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
698641
mlir::Type ty) const {
699642
CodeGenSpecifics::Marshalling marshal;
700-
auto sizeAndAlign = getSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
643+
auto sizeAndAlign =
644+
fir::getTypeSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
701645
// The stack is always 8 byte aligned (note 14 in 3.2.3).
702646
unsigned short align =
703647
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "flang/Optimizer/Dialect/FIRType.h"
1414
#include "flang/ISO_Fortran_binding_wrapper.h"
15+
#include "flang/Optimizer/Builder/Todo.h"
1516
#include "flang/Optimizer/Dialect/FIRDialect.h"
1617
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
1718
#include "flang/Tools/PointerModels.h"
@@ -1339,3 +1340,55 @@ void FIROpsDialect::registerTypes() {
13391340
fir::LLVMPointerType::attachInterface<
13401341
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
13411342
}
1343+
1344+
std::pair<std::uint64_t, unsigned short>
1345+
fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
1346+
const mlir::DataLayout &dl,
1347+
const fir::KindMapping &kindMap) {
1348+
if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
1349+
llvm::TypeSize size = dl.getTypeSize(ty);
1350+
unsigned short alignment = dl.getTypeABIAlignment(ty);
1351+
return {size, alignment};
1352+
}
1353+
if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
1354+
auto [floatSize, floatAlign] =
1355+
getTypeSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
1356+
return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
1357+
}
1358+
if (auto real = mlir::dyn_cast<fir::RealType>(ty))
1359+
return getTypeSizeAndAlignment(loc, real.getFloatType(kindMap), dl,
1360+
kindMap);
1361+
1362+
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1363+
auto [eleSize, eleAlign] =
1364+
getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
1365+
1366+
std::uint64_t size =
1367+
llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
1368+
return {size, eleAlign};
1369+
}
1370+
if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
1371+
std::uint64_t size = 0;
1372+
unsigned short align = 1;
1373+
for (auto component : recTy.getTypeList()) {
1374+
auto [compSize, compAlign] =
1375+
getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
1376+
size =
1377+
llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
1378+
align = std::max(align, compAlign);
1379+
}
1380+
return {size, align};
1381+
}
1382+
if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
1383+
mlir::Type intTy = mlir::IntegerType::get(
1384+
logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
1385+
return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1386+
}
1387+
if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
1388+
mlir::Type intTy = mlir::IntegerType::get(
1389+
character.getContext(),
1390+
kindMap.getCharacterBitsize(character.getFKind()));
1391+
return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1392+
}
1393+
TODO(loc, "computing size of a component");
1394+
}

flang/lib/Optimizer/Transforms/LoopVersioning.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "flang/Optimizer/Dialect/FIRType.h"
5050
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
5151
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
52+
#include "flang/Optimizer/Support/DataLayout.h"
5253
#include "flang/Optimizer/Transforms/Passes.h"
5354
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5455
#include "mlir/IR/Dominance.h"
@@ -241,6 +242,12 @@ void LoopVersioningPass::runOnOperation() {
241242
mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
242243
fir::KindMapping kindMap = fir::getKindMapping(module);
243244
mlir::SmallVector<ArgInfo, 4> argsOfInterest;
245+
std::optional<mlir::DataLayout> dl =
246+
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
247+
if (!dl)
248+
mlir::emitError(module.getLoc(),
249+
"data layout attribute is required to perform " DEBUG_TYPE
250+
"pass");
244251
for (auto &arg : args) {
245252
// Optional arguments must be checked for IsPresent before
246253
// looking for the bounds. They are unsupported for the time being.
@@ -256,11 +263,13 @@ void LoopVersioningPass::runOnOperation() {
256263
seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
257264
size_t typeSize = 0;
258265
mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
259-
if (elementType.isa<mlir::FloatType>() ||
260-
elementType.isa<mlir::IntegerType>())
261-
typeSize = elementType.getIntOrFloatBitWidth() / 8;
262-
else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
263-
typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
266+
if (mlir::isa<mlir::FloatType>(elementType) ||
267+
mlir::isa<mlir::IntegerType>(elementType) ||
268+
mlir::isa<fir::ComplexType>(elementType)) {
269+
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignment(
270+
arg.getLoc(), elementType, *dl, kindMap);
271+
typeSize = llvm::alignTo(eleSize, eleAlign);
272+
}
264273
if (typeSize)
265274
argsOfInterest.push_back({arg, typeSize, rank, {}});
266275
else

flang/test/Transforms/loop-versioning.fir

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// sum = sum + a(i)
1212
// end do
1313
// end subroutine sum1d
14-
module {
14+
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
1515
func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
1616
%decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
1717
%rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
@@ -1556,5 +1556,89 @@ func.func @minloc(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "x"}, %ar
15561556
// CHECK: fir.if %{{.*}} {
15571557
// CHECK: {{.*}} = arith.cmpi eq, %[[V17]], %c2147483647_i32
15581558

1559+
func.func @_QPtest_real10(%arg0: !fir.box<!fir.array<?x?xf80>> {fir.bindc_name = "a"}) -> f80 {
1560+
%c10 = arith.constant 10 : index
1561+
%c1 = arith.constant 1 : index
1562+
%cst = arith.constant 0.000000e+00 : f80
1563+
%0 = fir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "_QFtest_real10Ea"} : (!fir.box<!fir.array<?x?xf80>>) -> !fir.box<!fir.array<?x?xf80>>
1564+
%1 = fir.rebox %0 : (!fir.box<!fir.array<?x?xf80>>) -> !fir.box<!fir.array<?x?xf80>>
1565+
%2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_real10Ei"}
1566+
%3 = fir.declare %2 {uniq_name = "_QFtest_real10Ei"} : (!fir.ref<i32>) -> !fir.ref<i32>
1567+
%4 = fir.alloca f80 {bindc_name = "res", uniq_name = "_QFtest_real10Eres"}
1568+
%5 = fir.declare %4 {uniq_name = "_QFtest_real10Eres"} : (!fir.ref<f80>) -> !fir.ref<f80>
1569+
%6 = fir.address_of(@_QFtest_real10ECxdp) : !fir.ref<i32>
1570+
%7 = fir.declare %6 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QFtest_real10ECxdp"} : (!fir.ref<i32>) -> !fir.ref<i32>
1571+
fir.store %cst to %5 : !fir.ref<f80>
1572+
%8 = fir.convert %c1 : (index) -> i32
1573+
%9:2 = fir.do_loop %arg1 = %c1 to %c10 step %c1 iter_args(%arg2 = %8) -> (index, i32) {
1574+
fir.store %arg2 to %3 : !fir.ref<i32>
1575+
%11 = fir.load %5 : !fir.ref<f80>
1576+
%12 = fir.load %3 : !fir.ref<i32>
1577+
%13 = fir.convert %12 : (i32) -> i64
1578+
%14 = fir.array_coor %1 %13, %13 : (!fir.box<!fir.array<?x?xf80>>, i64, i64) -> !fir.ref<f80>
1579+
%15 = fir.load %14 : !fir.ref<f80>
1580+
%16 = arith.addf %11, %15 fastmath<contract> : f80
1581+
fir.store %16 to %5 : !fir.ref<f80>
1582+
%17 = arith.addi %arg1, %c1 : index
1583+
%18 = fir.load %3 : !fir.ref<i32>
1584+
%19 = arith.addi %18, %8 : i32
1585+
fir.result %17, %19 : index, i32
1586+
}
1587+
fir.store %9#1 to %3 : !fir.ref<i32>
1588+
%10 = fir.load %5 : !fir.ref<f80>
1589+
return %10 : f80
1590+
}
1591+
// CHECK-LABEL: func.func @_QPtest_real10(
1592+
// CHECK: fir.if
1593+
// CHECK: fir.do_loop
1594+
// CHECK-DAG: arith.shrsi %{{[^,]*}}, %[[SHIFT:.*]] : index
1595+
// CHECK-DAG: %[[SHIFT]] = arith.constant 4 : index
1596+
// CHECK: fir.result
1597+
// CHECK: } else {
1598+
// CHECK: fir.do_loop
1599+
1600+
func.func @_QPtest_complex10(%arg0: !fir.box<!fir.array<?x?x!fir.complex<10>>> {fir.bindc_name = "a"}) -> !fir.complex<10> {
1601+
%c10 = arith.constant 10 : index
1602+
%c1 = arith.constant 1 : index
1603+
%cst = arith.constant 0.000000e+00 : f80
1604+
%0 = fir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "_QFtest_complex10Ea"} : (!fir.box<!fir.array<?x?x!fir.complex<10>>>) -> !fir.box<!fir.array<?x?x!fir.complex<10>>>
1605+
%1 = fir.rebox %0 : (!fir.box<!fir.array<?x?x!fir.complex<10>>>) -> !fir.box<!fir.array<?x?x!fir.complex<10>>>
1606+
%2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_complex10Ei"}
1607+
%3 = fir.declare %2 {uniq_name = "_QFtest_complex10Ei"} : (!fir.ref<i32>) -> !fir.ref<i32>
1608+
%4 = fir.alloca !fir.complex<10> {bindc_name = "res", uniq_name = "_QFtest_complex10Eres"}
1609+
%5 = fir.declare %4 {uniq_name = "_QFtest_complex10Eres"} : (!fir.ref<!fir.complex<10>>) -> !fir.ref<!fir.complex<10>>
1610+
%6 = fir.address_of(@_QFtest_complex10ECxdp) : !fir.ref<i32>
1611+
%7 = fir.declare %6 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QFtest_complex10ECxdp"} : (!fir.ref<i32>) -> !fir.ref<i32>
1612+
%8 = fir.undefined !fir.complex<10>
1613+
%9 = fir.insert_value %8, %cst, [0 : index] : (!fir.complex<10>, f80) -> !fir.complex<10>
1614+
%10 = fir.insert_value %9, %cst, [1 : index] : (!fir.complex<10>, f80) -> !fir.complex<10>
1615+
fir.store %10 to %5 : !fir.ref<!fir.complex<10>>
1616+
%11 = fir.convert %c1 : (index) -> i32
1617+
%12:2 = fir.do_loop %arg1 = %c1 to %c10 step %c1 iter_args(%arg2 = %11) -> (index, i32) {
1618+
fir.store %arg2 to %3 : !fir.ref<i32>
1619+
%14 = fir.load %5 : !fir.ref<!fir.complex<10>>
1620+
%15 = fir.load %3 : !fir.ref<i32>
1621+
%16 = fir.convert %15 : (i32) -> i64
1622+
%17 = fir.array_coor %1 %16, %16 : (!fir.box<!fir.array<?x?x!fir.complex<10>>>, i64, i64) -> !fir.ref<!fir.complex<10>>
1623+
%18 = fir.load %17 : !fir.ref<!fir.complex<10>>
1624+
%19 = fir.addc %14, %18 {fastmath = #arith.fastmath<contract>} : !fir.complex<10>
1625+
fir.store %19 to %5 : !fir.ref<!fir.complex<10>>
1626+
%20 = arith.addi %arg1, %c1 : index
1627+
%21 = fir.load %3 : !fir.ref<i32>
1628+
%22 = arith.addi %21, %11 : i32
1629+
fir.result %20, %22 : index, i32
1630+
}
1631+
fir.store %12#1 to %3 : !fir.ref<i32>
1632+
%13 = fir.load %5 : !fir.ref<!fir.complex<10>>
1633+
return %13 : !fir.complex<10>
1634+
}
1635+
// CHECK-LABEL: func.func @_QPtest_complex10(
1636+
// CHECK: fir.if
1637+
// CHECK: fir.do_loop
1638+
// CHECK-DAG: arith.shrsi %{{[^,]*}}, %[[SHIFT:.*]] : index
1639+
// CHECK-DAG: %[[SHIFT]] = arith.constant 5 : index
1640+
// CHECK: fir.result
1641+
// CHECK: } else {
1642+
// CHECK: fir.do_loop
15591643

15601644
} // End module

0 commit comments

Comments
 (0)