Skip to content

Commit 396ead9

Browse files
committed
[flang] Use proper attributes for runtime calls with 'i1' arguments/returns.
Clang uses signext/zeroext attributes for integer arguments shorter than the default 'int' type on a target. So Flang has to match this for functions from Fortran runtime and also for BIND(C) routines. This patch implements ABI adjustments only for Fortran runtime calls. BIND(C) part will be done separately. This resolves llvm#58579 Differential Revision: https://reviews.llvm.org/D137050
1 parent 3fb08d1 commit 396ead9

File tree

8 files changed

+252
-14
lines changed

8 files changed

+252
-14
lines changed

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Common/Fortran.h"
2121
#include "flang/Common/uint128.h"
2222
#include "flang/Optimizer/Builder/FIRBuilder.h"
23+
#include "flang/Optimizer/Dialect/FIRDialect.h"
2324
#include "flang/Optimizer/Dialect/FIRType.h"
2425
#include "mlir/IR/BuiltinTypes.h"
2526
#include "mlir/IR/MLIRContext.h"
@@ -411,7 +412,7 @@ static mlir::func::FuncOp getRuntimeFunc(mlir::Location loc,
411412
return func;
412413
auto funTy = RuntimeEntry::getTypeModel()(builder.getContext());
413414
func = builder.createFunction(loc, name, funTy);
414-
func->setAttr("fir.runtime", builder.getUnitAttr());
415+
func->setAttr(FIROpsDialect::getFirRuntimeAttrName(), builder.getUnitAttr());
415416
return func;
416417
}
417418

flang/include/flang/Optimizer/Dialect/FIRDialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ class FIROpsDialect final : public mlir::Dialect {
3737
void printAttribute(mlir::Attribute attr,
3838
mlir::DialectAsmPrinter &p) const override;
3939

40+
/// Return string name of fir.runtime attribute.
41+
static constexpr llvm::StringRef getFirRuntimeAttrName() {
42+
return "fir.runtime";
43+
}
44+
4045
private:
4146
// Register the Attributes of this dialect.
4247
void registerAttributes();

flang/lib/Lower/IO.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "flang/Optimizer/Builder/FIRBuilder.h"
2828
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
2929
#include "flang/Optimizer/Builder/Todo.h"
30+
#include "flang/Optimizer/Dialect/FIRDialect.h"
3031
#include "flang/Optimizer/Support/FIRContext.h"
3132
#include "flang/Parser/parse-tree.h"
3233
#include "flang/Runtime/io-api.h"
@@ -167,7 +168,8 @@ static mlir::func::FuncOp getIORuntimeFunc(mlir::Location loc,
167168
return func;
168169
auto funTy = getTypeModel<E>()(builder.getContext());
169170
func = builder.createFunction(loc, name, funTy);
170-
func->setAttr("fir.runtime", builder.getUnitAttr());
171+
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
172+
builder.getUnitAttr());
171173
func->setAttr("fir.io", builder.getUnitAttr());
172174
return func;
173175
}

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "flang/Optimizer/Builder/Runtime/Stop.h"
3333
#include "flang/Optimizer/Builder/Runtime/Transformational.h"
3434
#include "flang/Optimizer/Builder/Todo.h"
35+
#include "flang/Optimizer/Dialect/FIRDialect.h"
3536
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
3637
#include "flang/Optimizer/Support/FatalError.h"
3738
#include "flang/Runtime/entry-names.h"
@@ -1684,7 +1685,8 @@ static mlir::func::FuncOp getFuncOp(mlir::Location loc,
16841685
const RuntimeFunction &runtime) {
16851686
mlir::func::FuncOp function = builder.addNamedFunction(
16861687
loc, runtime.symbol, runtime.typeGenerator(builder.getContext()));
1687-
function->setAttr("fir.runtime", builder.getUnitAttr());
1688+
function->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
1689+
builder.getUnitAttr());
16881690
return function;
16891691
}
16901692

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222

2323
using namespace fir;
2424

25+
namespace fir::details {
26+
llvm::StringRef Attributes::getIntExtensionAttrName() const {
27+
// The attribute names are available via LLVM dialect interfaces
28+
// like getZExtAttrName(), getByValAttrName(), etc., so we'd better
29+
// use them than literals.
30+
if (isZeroExt())
31+
return "llvm.zeroext";
32+
else if (isSignExt())
33+
return "llvm.signext";
34+
return {};
35+
}
36+
} // namespace fir::details
37+
2538
// Reduce a REAL/float type to the floating point semantics.
2639
static const llvm::fltSemantics &floatToSemantics(const KindMapping &kindMap,
2740
mlir::Type type) {
@@ -67,6 +80,46 @@ struct GenericTarget : public CodeGenSpecifics {
6780
/*sret=*/sret, /*append=*/!sret});
6881
return marshal;
6982
}
83+
84+
CodeGenSpecifics::Marshalling
85+
integerArgumentType(mlir::Location loc,
86+
mlir::IntegerType argTy) const override {
87+
CodeGenSpecifics::Marshalling marshal;
88+
AT::IntegerExtension intExt = AT::IntegerExtension::None;
89+
if (argTy.getWidth() < getCIntTypeWidth()) {
90+
// isSigned() and isUnsigned() branches below are dead code currently.
91+
// If needed, we can generate calls with signed/unsigned argument types
92+
// to more precisely match C side (e.g. for Fortran runtime functions
93+
// with 'unsigned short' arguments).
94+
if (argTy.isSigned())
95+
intExt = AT::IntegerExtension::Sign;
96+
else if (argTy.isUnsigned())
97+
intExt = AT::IntegerExtension::Zero;
98+
else if (argTy.isSignless()) {
99+
// Zero extend for 'i1' and sign extend for other types.
100+
if (argTy.getWidth() == 1)
101+
intExt = AT::IntegerExtension::Zero;
102+
else
103+
intExt = AT::IntegerExtension::Sign;
104+
}
105+
}
106+
107+
marshal.emplace_back(argTy, AT{/*alignment=*/0, /*byval=*/false,
108+
/*sret=*/false, /*append=*/false,
109+
/*intExt=*/intExt});
110+
return marshal;
111+
}
112+
113+
CodeGenSpecifics::Marshalling
114+
integerReturnType(mlir::Location loc,
115+
mlir::IntegerType argTy) const override {
116+
return integerArgumentType(loc, argTy);
117+
}
118+
119+
// Width of 'int' type is 32-bits for almost all targets, except
120+
// for AVR and MSP430 (see TargetInfo initializations
121+
// in clang/lib/Basic/Targets).
122+
unsigned char getCIntTypeWidth() const override { return 32; }
70123
};
71124
} // namespace
72125

flang/lib/Optimizer/CodeGen/Target.h

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,29 @@ namespace details {
2929
/// LLVMContext.
3030
class Attributes {
3131
public:
32+
enum class IntegerExtension { None, Zero, Sign };
33+
3234
Attributes(unsigned short alignment = 0, bool byval = false,
33-
bool sret = false, bool append = false)
34-
: alignment{alignment}, byval{byval}, sret{sret}, append{append} {}
35+
bool sret = false, bool append = false,
36+
IntegerExtension intExt = IntegerExtension::None)
37+
: alignment{alignment}, byval{byval}, sret{sret}, append{append},
38+
intExt{intExt} {}
3539

3640
unsigned getAlignment() const { return alignment; }
3741
bool hasAlignment() const { return alignment != 0; }
3842
bool isByVal() const { return byval; }
3943
bool isSRet() const { return sret; }
4044
bool isAppend() const { return append; }
45+
bool isZeroExt() const { return intExt == IntegerExtension::Zero; }
46+
bool isSignExt() const { return intExt == IntegerExtension::Sign; }
47+
llvm::StringRef getIntExtensionAttrName() const;
4148

4249
private:
4350
unsigned short alignment{};
4451
bool byval : 1;
4552
bool sret : 1;
4653
bool append : 1;
54+
IntegerExtension intExt;
4755
};
4856

4957
} // namespace details
@@ -94,6 +102,47 @@ class CodeGenSpecifics {
94102
virtual Marshalling boxcharArgumentType(mlir::Type eleTy,
95103
bool sret = false) const = 0;
96104

105+
// Compute ABI rules for an integer argument of the given mlir::IntegerType
106+
// \p argTy. Note that this methods is supposed to be called for
107+
// arguments passed by value not via reference, e.g. the 'i1' argument here:
108+
// declare i1 @_FortranAioOutputLogical(ptr, i1)
109+
//
110+
// \p loc is the location of the operation using/specifying the argument.
111+
//
112+
// Currently, the only supported marshalling is whether the argument
113+
// should be zero or sign extended.
114+
//
115+
// The zero/sign extension is especially important to comply with the ABI
116+
// used by C/C++ compiler that builds Fortran runtime. As in the above
117+
// example the callee will expect the caller to zero extend the second
118+
// argument up to the size of the C/C++'s 'int' type.
119+
// The corresponding handling in clang is done in
120+
// DefaultABIInfo::classifyArgumentType(), and the logic may brielfy
121+
// be explained as some sort of extension is required if the integer
122+
// type is shorter than the size of 'int' for the target.
123+
// The related code is located in ASTContext::isPromotableIntegerType()
124+
// and ABIInfo::isPromotableIntegerTypeForABI().
125+
// In particular, the latter returns 'true' for 'bool', several kinds
126+
// of 'char', 'short', 'wchar' and enumerated types.
127+
// The type of the extensions (zero or sign) depends on the signedness
128+
// of the original language type.
129+
//
130+
// It is not clear how to handle signless integer types.
131+
// From the point of Fortran-C interface all supported integer types
132+
// seem to be signed except for CFI_type_Bool/bool that is supported
133+
// via signless 'i1', but that is treated as unsigned type by clang
134+
// (e.g. 'bool' arguments are using 'zeroext' ABI).
135+
virtual Marshalling integerArgumentType(mlir::Location loc,
136+
mlir::IntegerType argTy) const = 0;
137+
138+
// By default, integer argument and return values use the same
139+
// zero/sign extension rules.
140+
virtual Marshalling integerReturnType(mlir::Location loc,
141+
mlir::IntegerType argTy) const = 0;
142+
143+
// Returns width in bits of C/C++ 'int' type size.
144+
virtual unsigned char getCIntTypeWidth() const = 0;
145+
97146
protected:
98147
mlir::MLIRContext &context;
99148
llvm::Triple triple;

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
100100
// Convert ops in target-specific patterns.
101101
mod.walk([&](mlir::Operation *op) {
102102
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
103-
if (!hasPortableSignature(call.getFunctionType()))
103+
if (!hasPortableSignature(call.getFunctionType(), op))
104104
convertCallOp(call);
105105
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
106-
if (!hasPortableSignature(dispatch.getFunctionType()))
106+
if (!hasPortableSignature(dispatch.getFunctionType(), op))
107107
convertCallOp(dispatch);
108108
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
109109
if (addr.getType().isa<mlir::FunctionType>() &&
110-
!hasPortableSignature(addr.getType()))
110+
!hasPortableSignature(addr.getType(), op))
111111
convertAddrOp(addr);
112112
}
113113
});
@@ -443,19 +443,23 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
443443
/// then it is considered portable for any target, and this function will
444444
/// return `true`. Otherwise, the signature is not portable and `false` is
445445
/// returned.
446-
bool hasPortableSignature(mlir::Type signature) {
446+
bool hasPortableSignature(mlir::Type signature, mlir::Operation *op) {
447447
assert(signature.isa<mlir::FunctionType>());
448448
auto func = signature.dyn_cast<mlir::FunctionType>();
449+
bool hasFirRuntime = op->hasAttrOfType<mlir::UnitAttr>(
450+
fir::FIROpsDialect::getFirRuntimeAttrName());
449451
for (auto ty : func.getResults())
450452
if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) ||
451-
(fir::isa_complex(ty) && !noComplexConversion)) {
453+
(fir::isa_complex(ty) && !noComplexConversion) ||
454+
(ty.isa<mlir::IntegerType>() && hasFirRuntime)) {
452455
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
453456
return false;
454457
}
455458
for (auto ty : func.getInputs())
456459
if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) &&
457460
!noCharacterConversion) ||
458-
(fir::isa_complex(ty) && !noComplexConversion)) {
461+
(fir::isa_complex(ty) && !noComplexConversion) ||
462+
(ty.isa<mlir::IntegerType>() && hasFirRuntime)) {
459463
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
460464
return false;
461465
}
@@ -476,13 +480,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
476480
/// the immediately subsequent target code gen.
477481
void convertSignature(mlir::func::FuncOp func) {
478482
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
479-
if (hasPortableSignature(funcTy) && !hasHostAssociations(func))
483+
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
480484
return;
481485
llvm::SmallVector<mlir::Type> newResTys;
482486
llvm::SmallVector<mlir::Type> newInTys;
483487
llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
484488
llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
485489
llvm::SmallVector<FixupTy> fixups;
490+
llvm::SmallVector<std::pair<unsigned, mlir::NamedAttrList>, 1> resultAttrs;
486491

487492
// Save argument attributes in case there is a shift so we can replace them
488493
// correctly.
@@ -509,6 +514,22 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
509514
else
510515
doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
511516
})
517+
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
518+
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
519+
assert(m.size() == 1);
520+
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
521+
auto retTy = std::get<mlir::Type>(m[0]);
522+
std::size_t resId = newResTys.size();
523+
llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
524+
if (!extensionAttrName.empty() &&
525+
// TODO: we have to do the same for BIND(C) routines.
526+
func->hasAttrOfType<mlir::UnitAttr>(
527+
fir::FIROpsDialect::getFirRuntimeAttrName()))
528+
resultAttrs.emplace_back(
529+
resId, rewriter->getNamedAttr(extensionAttrName,
530+
rewriter->getUnitAttr()));
531+
newResTys.push_back(retTy);
532+
})
512533
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
513534

514535
// Saved potential shift in argument. Handling of result can add arguments
@@ -572,6 +593,26 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
572593
newInTys.push_back(ty);
573594
}
574595
})
596+
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
597+
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
598+
assert(m.size() == 1);
599+
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
600+
auto argTy = std::get<mlir::Type>(m[0]);
601+
auto argNo = newInTys.size();
602+
llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
603+
if (!extensionAttrName.empty() &&
604+
// TODO: we have to do the same for BIND(C) routines.
605+
func->hasAttrOfType<mlir::UnitAttr>(
606+
fir::FIROpsDialect::getFirRuntimeAttrName())) {
607+
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
608+
[=](mlir::func::FuncOp func) {
609+
func.setArgAttr(
610+
argNo, extensionAttrName,
611+
mlir::UnitAttr::get(func.getContext()));
612+
});
613+
}
614+
newInTys.push_back(argTy);
615+
})
575616
.Default([&](mlir::Type ty) { newInTys.push_back(ty); });
576617

577618
if (func.getArgAttrOfType<mlir::UnitAttr>(index,
@@ -608,14 +649,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
608649
case FixupTy::Codes::ArgumentType: {
609650
// Argument is pass-by-value, but its type has likely been modified to
610651
// suit the target ABI convention.
652+
auto oldArgTy =
653+
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
654+
// If type did not change, keep the original argument.
655+
if (newInTys[fixup.index] == oldArgTy)
656+
break;
657+
611658
auto newArg = func.front().insertArgument(fixup.index,
612659
newInTys[fixup.index], loc);
613660
rewriter->setInsertionPointToStart(&func.front());
614661
auto mem =
615662
rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
616663
rewriter->create<fir::StoreOp>(loc, newArg, mem);
617-
auto oldArgTy =
618-
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
619664
auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
620665
mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
621666
func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
@@ -744,6 +789,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
744789
func.setArgAttr(extraAttr.first, extraAttr.second.getName(),
745790
extraAttr.second.getValue());
746791

792+
for (auto [resId, resAttrList] : resultAttrs)
793+
for (mlir::NamedAttribute resAttr : resAttrList)
794+
func.setResultAttr(resId, resAttr.getName(), resAttr.getValue());
795+
747796
// Replace attributes to the correct argument if there was an argument shift
748797
// to the right.
749798
if (argumentShift > 0) {

0 commit comments

Comments
 (0)