Skip to content

Commit b853157

Browse files
committed
Support for target specific lowering in the Tilikum bridge.
To generate correct code for a chosen target, the Tilikum bridge must know what the selected target is and the conventions used for the specific target ABI. The properties of the target influence the calling conventions and LLVM IR that must be generated. Tilikum is the last point before any high-level abstractions must be considered and correctly translated to LLVM IR. These changed rework the Tilikum bridge to use a target specifier and convert the calling conventions and memory layouts appropriate for the selected target. Two target specifications are implemented. i386-unknown-linux-gnu and x86_64-unknown-linux-gnu. Others can be added as needed. Two high-level type abstractions are considered: COMPLEX and CHARACTER. Moving these target specific lowerings to a common place in code gen eliminates the need to perform heroics with custom code in lowering and/or reliance on assuming the target is known by implication at compiler compile-time.
1 parent 7e65641 commit b853157

35 files changed

+1575
-312
lines changed

flang/include/flang/Optimizer/CodeGen/CGPasses.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,27 @@
1717
include "mlir/Pass/PassBase.td"
1818

1919
def CodeGenRewrite : FunctionPass<"cg-rewrite"> {
20-
let summary = "Rewrite some FIR ops into their code-gen forms.";
20+
let summary = "Rewrite some FIR ops into their code-gen forms. "
21+
"Fuse specific subgraphs into single Ops for code generation.";
2122
let constructor = "fir::createFirCodeGenRewritePass()";
23+
let dependentDialects = ["fir::FIROpsDialect"];
24+
}
25+
26+
def TargetRewrite : Pass<"target-rewrite", "mlir::ModuleOp"> {
27+
let summary = "Rewrite some FIR dialect into target specific forms. "
28+
"Certain abstractions in the FIR dialect need to be rewritten "
29+
"to reflect representations that may differ based on the "
30+
"target machine.";
31+
let constructor = "fir::createFirTargetRewritePass()";
32+
let dependentDialects = ["fir::FIROpsDialect"];
33+
let options = [
34+
Option<"noCharacterConversion", "no-character-conversion",
35+
"bool", /*default=*/"false",
36+
"Disable target-specific conversion of CHARACTER.">,
37+
Option<"noComplexConversion", "no-complex-conversion",
38+
"bool", /*default=*/"false",
39+
"Disable target-specific conversion of COMPLEX.">
40+
];
2241
}
2342

2443
#endif // FLANG_OPTIMIZER_CODEGEN_PASSES

flang/include/flang/Optimizer/CodeGen/CodeGen.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef OPTIMIZER_CODEGEN_CODEGEN_H
1010
#define OPTIMIZER_CODEGEN_CODEGEN_H
1111

12+
#include "mlir/IR/Module.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Pass/PassRegistry.h"
1415
#include <memory>
@@ -21,6 +22,17 @@ struct NameUniquer;
2122
/// the code gen (to LLVM-IR dialect) conversion.
2223
std::unique_ptr<mlir::Pass> createFirCodeGenRewritePass();
2324

25+
/// FirTargetRewritePass options.
26+
struct TargetRewriteOptions {
27+
bool noCharacterConversion{};
28+
bool noComplexConversion{};
29+
};
30+
31+
/// Prerequiste pass for code gen. Perform intermediate rewrites to tailor the
32+
/// IR for the chosen target.
33+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createFirTargetRewritePass(
34+
const TargetRewriteOptions &options = TargetRewriteOptions());
35+
2436
/// Convert FIR to the LLVM IR dialect
2537
std::unique_ptr<mlir::Pass> createFIRToLLVMPass(NameUniquer &uniquer);
2638

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ struct BoxTypeStorage;
4343
struct BoxCharTypeStorage;
4444
struct BoxProcTypeStorage;
4545
struct CharacterTypeStorage;
46-
struct CplxTypeStorage;
46+
struct ComplexTypeStorage;
4747
struct FieldTypeStorage;
4848
struct HeapTypeStorage;
49-
struct IntTypeStorage;
49+
struct IntegerTypeStorage;
5050
struct LenTypeStorage;
5151
struct LogicalTypeStorage;
5252
struct PointerTypeStorage;
@@ -58,6 +58,7 @@ struct ShapeTypeStorage;
5858
struct ShapeShiftTypeStorage;
5959
struct SliceTypeStorage;
6060
struct TypeDescTypeStorage;
61+
struct VectorTypeStorage;
6162
} // namespace detail
6263

6364
// These isa_ routines follow the precedent of llvm::isa_or_null<>
@@ -125,11 +126,11 @@ class CharacterType
125126
/// Model of a Fortran COMPLEX intrinsic type, including the KIND type
126127
/// parameter. COMPLEX is a floating point type with a real and imaginary
127128
/// member.
128-
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
129-
detail::CplxTypeStorage> {
129+
class ComplexType : public mlir::Type::TypeBase<fir::ComplexType, mlir::Type,
130+
detail::ComplexTypeStorage> {
130131
public:
131132
using Base::Base;
132-
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
133+
static fir::ComplexType get(mlir::MLIRContext *ctxt, KindTy kind);
133134

134135
/// Get the corresponding fir.real<k> type.
135136
mlir::Type getElementType() const;
@@ -139,19 +140,18 @@ class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
139140

140141
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
141142
/// parameter.
142-
class IntType
143-
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
143+
class IntegerType : public mlir::Type::TypeBase<fir::IntegerType, mlir::Type,
144+
detail::IntegerTypeStorage> {
144145
public:
145146
using Base::Base;
146-
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
147+
static fir::IntegerType get(mlir::MLIRContext *ctxt, KindTy kind);
147148
KindTy getFKind() const;
148149
};
149150

150151
/// Model of a Fortran LOGICAL intrinsic type, including the KIND type
151152
/// parameter.
152-
class LogicalType
153-
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
154-
detail::LogicalTypeStorage> {
153+
class LogicalType : public mlir::Type::TypeBase<LogicalType, mlir::Type,
154+
detail::LogicalTypeStorage> {
155155
public:
156156
using Base::Base;
157157
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -414,14 +414,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
414414
llvm::StringRef name);
415415
};
416416

417-
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
418-
419-
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
420-
421-
/// Guarantee `type` is a scalar integral type (standard Integer, standard
422-
/// Index, or FIR Int). Aborts execution if condition is false.
423-
void verifyIntegralType(mlir::Type type);
424-
425417
/// Is `t` a FIR Real or MLIR Float type?
426418
inline bool isa_real(mlir::Type t) {
427419
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -430,12 +422,38 @@ inline bool isa_real(mlir::Type t) {
430422
/// Is `t` an integral type?
431423
inline bool isa_integer(mlir::Type t) {
432424
return t.isa<mlir::IndexType>() || t.isa<mlir::IntegerType>() ||
433-
t.isa<fir::IntType>();
425+
t.isa<fir::IntegerType>();
434426
}
435427

428+
/// Replacement for the standard dialect's vector type. Relaxes some of the
429+
/// constraints and imposes some new ones.
430+
class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
431+
detail::VectorTypeStorage> {
432+
public:
433+
using Base::Base;
434+
435+
static fir::VectorType get(uint64_t len, mlir::Type eleTy);
436+
mlir::Type getEleTy() const;
437+
uint64_t getLen() const;
438+
439+
static mlir::LogicalResult
440+
verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
441+
static bool isValidElementType(mlir::Type t) {
442+
return isa_real(t) || isa_integer(t);
443+
}
444+
};
445+
446+
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
447+
448+
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
449+
450+
/// Guarantee `type` is a scalar integral type (standard Integer, standard
451+
/// Index, or FIR Int). Aborts execution if condition is false.
452+
void verifyIntegralType(mlir::Type type);
453+
436454
/// Is `t` a FIR or MLIR Complex type?
437455
inline bool isa_complex(mlir::Type t) {
438-
return t.isa<fir::CplxType>() || t.isa<mlir::ComplexType>();
456+
return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();
439457
}
440458

441459
inline bool isa_char_string(mlir::Type t) {

flang/include/flang/Optimizer/Support/FIRContext.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,24 @@ namespace fir {
2929
class KindMapping;
3030
struct NameUniquer;
3131

32-
/// Set the target triple for the module.
32+
/// Set the target triple for the module. `triple` must not be deallocated while
33+
/// module `mod` is still live.
3334
void setTargetTriple(mlir::ModuleOp mod, llvm::Triple &triple);
3435

3536
/// Get a pointer to the Triple instance from the Module. If none was set,
3637
/// returns a nullptr.
3738
llvm::Triple *getTargetTriple(mlir::ModuleOp mod);
3839

39-
/// Set the name uniquer for the module.
40+
/// Set the name uniquer for the module. `uniquer` must not be deallocated while
41+
/// module `mod` is still live.
4042
void setNameUniquer(mlir::ModuleOp mod, NameUniquer &uniquer);
4143

4244
/// Get a pointer to the NameUniquer instance from the Module. If none was set,
4345
/// returns a nullptr.
4446
NameUniquer *getNameUniquer(mlir::ModuleOp mod);
4547

46-
/// Set the kind mapping for the module.
48+
/// Set the kind mapping for the module. `kindMap` must not be deallocated while
49+
/// module `mod` is still live.
4750
void setKindMapping(mlir::ModuleOp mod, KindMapping &kindMap);
4851

4952
/// Get a pointer to the KindMapping instance from the Module. If none was set,
@@ -53,6 +56,9 @@ KindMapping *getKindMapping(mlir::ModuleOp mod);
5356
/// Helper for determining the target from the host, etc. Tools may use this
5457
/// function to provide a consistent interpretation of the `--target=<string>`
5558
/// command-line option.
59+
/// An empty string ("") or "default" will specify that the default triple
60+
/// should be used. "native" will specify that the host machine be used to
61+
/// construct the triple.
5662
std::string determineTargetTriple(llvm::StringRef triple);
5763

5864
} // namespace fir

flang/lib/Lower/ComplexExpr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
mlir::Type
1717
Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Type complexType) {
1818
return Fortran::lower::convertReal(
19-
builder.getContext(), complexType.cast<fir::CplxType>().getFKind());
19+
builder.getContext(), complexType.cast<fir::ComplexType>().getFKind());
2020
}
2121

2222
mlir::Type
@@ -27,7 +27,7 @@ Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Value cplx) {
2727
mlir::Value Fortran::lower::ComplexExprHelper::createComplex(fir::KindTy kind,
2828
mlir::Value real,
2929
mlir::Value imag) {
30-
auto complexTy = fir::CplxType::get(builder.getContext(), kind);
30+
auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
3131
mlir::Value und = builder.create<fir::UndefOp>(loc, complexTy);
3232
return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
3333
}

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/Support/CommandLine.h"
3333
#include "llvm/Support/ErrorHandling.h"
3434
#include "llvm/Support/raw_ostream.h"
35+
#define DEBUG_TYPE "flang-lower-expr"
3536

3637
#define TODO() llvm_unreachable("not yet implemented")
3738

@@ -362,7 +363,7 @@ class ExprLowering {
362363
Fortran::lower::getUnrestrictedIntrinsicSymbolRefAttr(
363364
builder, getLoc(), genericName, signature);
364365
mlir::Value funcPtr =
365-
builder.create<mlir::ConstantOp>(getLoc(), signature, symbolRefAttr);
366+
builder.create<fir::AddrOfOp>(getLoc(), signature, symbolRefAttr);
366367
return funcPtr;
367368
}
368369
const auto *symbol = proc.GetSymbol();
@@ -374,7 +375,7 @@ class ExprLowering {
374375
}
375376
auto name = converter.mangleName(*symbol);
376377
auto func = Fortran::lower::getOrDeclareFunction(name, proc, converter);
377-
mlir::Value funcPtr = builder.create<mlir::ConstantOp>(
378+
mlir::Value funcPtr = builder.create<fir::AddrOfOp>(
378379
getLoc(), func.getType(), builder.getSymbolRefAttr(name));
379380
return funcPtr;
380381
}
@@ -1350,6 +1351,7 @@ class ExprLowering {
13501351
}
13511352
}
13521353
auto result = genval(details.stmtFunction().value());
1354+
LLVM_DEBUG(llvm::errs() << "stmt-function: " << result << '\n');
13531355
// Remove dummy local arguments from the map.
13541356
for (const auto *dummySymbol : details.dummyArgs())
13551357
symMap.erase(*dummySymbol);
@@ -1469,7 +1471,7 @@ class ExprLowering {
14691471
if (callSiteType.getNumResults() != funcOpType.getNumResults() ||
14701472
callSiteType.getNumInputs() != funcOpType.getNumInputs())
14711473
funcPointer =
1472-
builder.create<mlir::ConstantOp>(getLoc(), funcOpType, symbolAttr);
1474+
builder.create<fir::AddrOfOp>(getLoc(), funcOpType, symbolAttr);
14731475
else
14741476
funcSymbolAttr = symbolAttr;
14751477
}
@@ -1585,6 +1587,8 @@ mlir::Value Fortran::lower::createSomeExpression(
15851587
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
15861588
Fortran::lower::SymMap &symMap) {
15871589
Fortran::lower::ExpressionContext unused;
1590+
LLVM_DEBUG(llvm::errs() << "expr: "; expr.AsFortran(llvm::errs());
1591+
llvm::errs() << '\n');
15881592
return ExprLowering{loc, converter, symMap, unused}.genValue(expr);
15891593
}
15901594

@@ -1593,6 +1597,8 @@ fir::ExtendedValue Fortran::lower::createSomeExtendedExpression(
15931597
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
15941598
Fortran::lower::SymMap &symMap,
15951599
const Fortran::lower::ExpressionContext &context) {
1600+
LLVM_DEBUG(llvm::errs() << "expr: "; expr.AsFortran(llvm::errs());
1601+
llvm::errs() << '\n');
15961602
return ExprLowering{loc, converter, symMap, context}.genExtValue(expr);
15971603
}
15981604

@@ -1601,6 +1607,8 @@ mlir::Value Fortran::lower::createSomeAddress(
16011607
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
16021608
Fortran::lower::SymMap &symMap) {
16031609
Fortran::lower::ExpressionContext unused;
1610+
LLVM_DEBUG(llvm::errs() << "address: "; expr.AsFortran(llvm::errs());
1611+
llvm::errs() << '\n');
16041612
return ExprLowering{loc, converter, symMap, unused}.genAddr(expr);
16051613
}
16061614

@@ -1609,6 +1617,8 @@ fir::ExtendedValue Fortran::lower::createSomeExtendedAddress(
16091617
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
16101618
Fortran::lower::SymMap &symMap,
16111619
const Fortran::lower::ExpressionContext &context) {
1620+
LLVM_DEBUG(llvm::errs() << "address: "; expr.AsFortran(llvm::errs());
1621+
llvm::errs() << '\n');
16121622
return ExprLowering{loc, converter, symMap, context}.genExtAddr(expr);
16131623
}
16141624

@@ -1618,6 +1628,7 @@ fir::ExtendedValue Fortran::lower::createStringLiteral(
16181628
assert(str.size() == len);
16191629
Fortran::lower::SymMap unused1;
16201630
Fortran::lower::ExpressionContext unused2;
1631+
LLVM_DEBUG(llvm::errs() << "string-lit: \"" << str << "\"\n");
16211632
return ExprLowering{loc, converter, unused1, unused2}.genStringLit(str, len);
16221633
}
16231634

flang/lib/Lower/ConvertType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ genFIRType<Fortran::common::TypeCategory::Complex>(mlir::MLIRContext *context,
167167
int KIND) {
168168
if (Fortran::evaluate::IsValidKindOfIntrinsicType(
169169
Fortran::common::TypeCategory::Complex, KIND))
170-
return fir::CplxType::get(context, KIND);
170+
return fir::ComplexType::get(context, KIND);
171171
return {};
172172
}
173173

flang/lib/Lower/FIRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ mlir::Value Fortran::lower::FirOpBuilder::convertWithSemantics(
153153
auto eleTy = helper.getComplexPartType(toTy);
154154
auto cast = createConvert(loc, eleTy, val);
155155
llvm::APFloat zero{
156-
kindMap.getFloatSemantics(toTy.cast<fir::CplxType>().getFKind()), 0};
156+
kindMap.getFloatSemantics(toTy.cast<fir::ComplexType>().getFKind()), 0};
157157
auto imag = createRealConstant(loc, eleTy, zero);
158158
return helper.createComplex(toTy, cast, imag);
159159
}

flang/lib/Lower/IO.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ static mlir::FuncOp getOutputFunc(mlir::Location loc,
203203
return ty.getWidth() <= 32
204204
? getIORuntimeFunc<mkIOKey(OutputReal32)>(loc, builder)
205205
: getIORuntimeFunc<mkIOKey(OutputReal64)>(loc, builder);
206-
if (auto ty = type.dyn_cast<fir::CplxType>())
206+
if (auto ty = type.dyn_cast<fir::ComplexType>())
207207
return ty.getFKind() <= 4
208208
? getIORuntimeFunc<mkIOKey(OutputComplex32)>(loc, builder)
209209
: getIORuntimeFunc<mkIOKey(OutputComplex64)>(loc, builder);
@@ -283,7 +283,7 @@ static mlir::FuncOp getInputFunc(mlir::Location loc,
283283
return ty.getWidth() <= 32
284284
? getIORuntimeFunc<mkIOKey(InputReal32)>(loc, builder)
285285
: getIORuntimeFunc<mkIOKey(InputReal64)>(loc, builder);
286-
if (auto ty = type.dyn_cast<fir::CplxType>())
286+
if (auto ty = type.dyn_cast<fir::ComplexType>())
287287
return ty.getFKind() <= 4
288288
? getIORuntimeFunc<mkIOKey(InputComplex32)>(loc, builder)
289289
: getIORuntimeFunc<mkIOKey(InputComplex64)>(loc, builder);

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ class FunctionDistance {
439439
// - or use evaluate/type.h
440440
if (auto r{t.dyn_cast<fir::RealType>()})
441441
return r.getFKind() * 4;
442-
if (auto cplx{t.dyn_cast<fir::CplxType>()})
442+
if (auto cplx{t.dyn_cast<fir::ComplexType>()})
443443
return cplx.getFKind() * 4;
444444
llvm_unreachable("not a floating-point type");
445445
}
@@ -459,8 +459,8 @@ class FunctionDistance {
459459
? Conversion::Narrow
460460
: Conversion::Extend;
461461
}
462-
if (auto fromCplxTy{from.dyn_cast<fir::CplxType>()}) {
463-
if (auto toCplxTy{to.dyn_cast<fir::CplxType>()}) {
462+
if (auto fromCplxTy{from.dyn_cast<fir::ComplexType>()}) {
463+
if (auto toCplxTy{to.dyn_cast<fir::ComplexType>()}) {
464464
return getFloatingPointWidth(fromCplxTy) >
465465
getFloatingPointWidth(toCplxTy)
466466
? Conversion::Narrow
@@ -837,7 +837,7 @@ IntrinsicLibrary::outlineInWrapper(GeneratorType generator,
837837

838838
auto funcType = getFunctionType(resultType, args, builder);
839839
auto wrapper = getWrapper(generator, name, funcType);
840-
return builder.create<mlir::CallOp>(loc, wrapper, args).getResult(0);
840+
return builder.create<fir::CallOp>(loc, wrapper, args).getResult(0);
841841
}
842842

843843
fir::ExtendedValue
@@ -857,7 +857,7 @@ IntrinsicLibrary::outlineInWrapper(ExtendedGenerator generator,
857857
auto funcType = getFunctionType(resultType, mlirArgs, builder);
858858
auto wrapper = getWrapper(generator, name, funcType);
859859
auto mlirResult =
860-
builder.create<mlir::CallOp>(loc, wrapper, mlirArgs).getResult(0);
860+
builder.create<fir::CallOp>(loc, wrapper, mlirArgs).getResult(0);
861861
return toExtendedValue(mlirResult, builder, loc);
862862
}
863863

@@ -884,7 +884,7 @@ IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
884884
for (const auto &pair : llvm::zip(actualFuncType.getInputs(), args))
885885
convertedArguments.push_back(
886886
builder.createConvert(loc, std::get<0>(pair), std::get<1>(pair)));
887-
auto call = builder.create<mlir::CallOp>(loc, funcOp, convertedArguments);
887+
auto call = builder.create<fir::CallOp>(loc, funcOp, convertedArguments);
888888
mlir::Type soughtType = soughtFuncType.getResult(0);
889889
return builder.createConvert(loc, soughtType, call.getResult(0));
890890
};

0 commit comments

Comments
 (0)