6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " flang/Optimizer/Builder/BoxValue.h"
9
10
#include " flang/Optimizer/Builder/FIRBuilder.h"
11
+ #include " flang/Optimizer/Builder/Runtime/RTBuilder.h"
12
+ #include " flang/Optimizer/Builder/Todo.h"
13
+ #include " flang/Optimizer/CodeGen/Target.h"
10
14
#include " flang/Optimizer/Dialect/CUF/CUFOps.h"
11
15
#include " flang/Optimizer/Dialect/FIRAttr.h"
12
16
#include " flang/Optimizer/Dialect/FIRDialect.h"
17
+ #include " flang/Optimizer/Dialect/FIROps.h"
13
18
#include " flang/Optimizer/Dialect/FIROpsSupport.h"
19
+ #include " flang/Optimizer/Support/DataLayout.h"
14
20
#include " flang/Optimizer/Transforms/CUFCommon.h"
21
+ #include " flang/Runtime/CUDA/registration.h"
15
22
#include " flang/Runtime/entry-names.h"
16
23
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
17
24
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
25
+ #include " mlir/IR/Value.h"
18
26
#include " mlir/Pass/Pass.h"
19
27
#include " llvm/ADT/SmallVector.h"
20
28
@@ -23,6 +31,8 @@ namespace fir {
23
31
#include " flang/Optimizer/Transforms/Passes.h.inc"
24
32
} // namespace fir
25
33
34
+ using namespace Fortran ::runtime::cuda;
35
+
26
36
namespace {
27
37
28
38
static constexpr llvm::StringRef cudaFortranCtorName{
@@ -34,13 +44,23 @@ struct CUFAddConstructor
34
44
void runOnOperation () override {
35
45
mlir::ModuleOp mod = getOperation ();
36
46
mlir::SymbolTable symTab (mod);
37
- mlir::OpBuilder builder{mod.getBodyRegion ()};
47
+ mlir::OpBuilder opBuilder{mod.getBodyRegion ()};
48
+ fir::FirOpBuilder builder (opBuilder, mod);
49
+ fir::KindMapping kindMap{fir::getKindMapping (mod)};
38
50
builder.setInsertionPointToEnd (mod.getBody ());
39
51
mlir::Location loc = mod.getLoc ();
40
52
auto *ctx = mod.getContext ();
41
53
auto voidTy = mlir::LLVM::LLVMVoidType::get (ctx);
54
+ auto idxTy = builder.getIndexType ();
42
55
auto funcTy =
43
56
mlir::LLVM::LLVMFunctionType::get (voidTy, {}, /* isVarArg=*/ false );
57
+ std::optional<mlir::DataLayout> dl =
58
+ fir::support::getOrSetDataLayout (mod, /* allowDefaultLayout=*/ false );
59
+ if (!dl) {
60
+ mlir::emitError (mod.getLoc (),
61
+ " data layout attribute is required to perform " +
62
+ getName () + " pass" );
63
+ }
44
64
45
65
// Symbol reference to CUFRegisterAllocator.
46
66
builder.setInsertionPointToEnd (mod.getBody ());
@@ -58,12 +78,13 @@ struct CUFAddConstructor
58
78
builder.setInsertionPointToStart (func.addEntryBlock (builder));
59
79
builder.create <mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
60
80
61
- // Register kernels
62
81
auto gpuMod = symTab.lookup <mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
63
82
if (gpuMod) {
64
83
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (ctx);
65
84
auto registeredMod = builder.create <cuf::RegisterModuleOp>(
66
85
loc, llvmPtrTy, mlir::SymbolRefAttr::get (ctx, gpuMod.getName ()));
86
+
87
+ // Register kernels
67
88
for (auto func : gpuMod.getOps <mlir::gpu::GPUFuncOp>()) {
68
89
if (func.isKernel ()) {
69
90
auto kernelName = mlir::SymbolRefAttr::get (
@@ -72,12 +93,55 @@ struct CUFAddConstructor
72
93
builder.create <cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
73
94
}
74
95
}
96
+
97
+ // Register variables
98
+ for (fir::GlobalOp globalOp : mod.getOps <fir::GlobalOp>()) {
99
+ auto attr = globalOp.getDataAttrAttr ();
100
+ if (!attr)
101
+ continue ;
102
+
103
+ mlir::func::FuncOp func;
104
+ switch (attr.getValue ()) {
105
+ case cuf::DataAttribute::Device:
106
+ case cuf::DataAttribute::Constant: {
107
+ func = fir::runtime::getRuntimeFunc<mkRTKey (CUFRegisterVariable)>(
108
+ loc, builder);
109
+ auto fTy = func.getFunctionType ();
110
+
111
+ // Global variable name
112
+ std::string gblNameStr = globalOp.getSymbol ().getValue ().str ();
113
+ gblNameStr += ' \0 ' ;
114
+ mlir::Value gblName = fir::getBase (
115
+ fir::factory::createStringLiteral (builder, loc, gblNameStr));
116
+
117
+ // Global variable size
118
+ auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash (
119
+ loc, globalOp.getType (), *dl, kindMap);
120
+ auto size =
121
+ builder.createIntegerConstant (loc, idxTy, sizeAndAlign.first );
122
+
123
+ // Global variable address
124
+ mlir::Value addr = builder.create <fir::AddrOfOp>(
125
+ loc, globalOp.resultType (), globalOp.getSymbol ());
126
+
127
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
128
+ builder, loc, fTy , registeredMod, addr, gblName, size)};
129
+ builder.create <fir::CallOp>(loc, func, args);
130
+ } break ;
131
+ case cuf::DataAttribute::Managed:
132
+ TODO (loc, " registration of managed variables" );
133
+ default :
134
+ break ;
135
+ }
136
+ if (!func)
137
+ continue ;
138
+ }
75
139
}
76
140
builder.create <mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
77
141
78
142
// Create the llvm.global_ctor with the function.
79
- // TODO: We might want to have a utility that retrieve it if already created
80
- // and adds new functions.
143
+ // TODO: We might want to have a utility that retrieves it if already
144
+ // created and adds new functions.
81
145
builder.setInsertionPointToEnd (mod.getBody ());
82
146
llvm::SmallVector<mlir::Attribute> funcs;
83
147
funcs.push_back (
0 commit comments