Skip to content

Commit b9978f8

Browse files
authored
[flang][cuda] Adding variable registration in constructor (llvm#113976)
1) Adding variable registration in constructor 2) Applying feedback from PR llvm#112989
1 parent b4e1af0 commit b9978f8

File tree

3 files changed

+91
-5
lines changed

3 files changed

+91
-5
lines changed

flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,23 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "flang/Optimizer/Builder/BoxValue.h"
910
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
12+
#include "flang/Optimizer/Builder/Todo.h"
13+
#include "flang/Optimizer/CodeGen/Target.h"
1014
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
1115
#include "flang/Optimizer/Dialect/FIRAttr.h"
1216
#include "flang/Optimizer/Dialect/FIRDialect.h"
17+
#include "flang/Optimizer/Dialect/FIROps.h"
1318
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
19+
#include "flang/Optimizer/Support/DataLayout.h"
1420
#include "flang/Optimizer/Transforms/CUFCommon.h"
21+
#include "flang/Runtime/CUDA/registration.h"
1522
#include "flang/Runtime/entry-names.h"
1623
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1724
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
25+
#include "mlir/IR/Value.h"
1826
#include "mlir/Pass/Pass.h"
1927
#include "llvm/ADT/SmallVector.h"
2028

@@ -23,6 +31,8 @@ namespace fir {
2331
#include "flang/Optimizer/Transforms/Passes.h.inc"
2432
} // namespace fir
2533

34+
using namespace Fortran::runtime::cuda;
35+
2636
namespace {
2737

2838
static constexpr llvm::StringRef cudaFortranCtorName{
@@ -34,13 +44,23 @@ struct CUFAddConstructor
3444
void runOnOperation() override {
3545
mlir::ModuleOp mod = getOperation();
3646
mlir::SymbolTable symTab(mod);
37-
mlir::OpBuilder builder{mod.getBodyRegion()};
47+
mlir::OpBuilder opBuilder{mod.getBodyRegion()};
48+
fir::FirOpBuilder builder(opBuilder, mod);
49+
fir::KindMapping kindMap{fir::getKindMapping(mod)};
3850
builder.setInsertionPointToEnd(mod.getBody());
3951
mlir::Location loc = mod.getLoc();
4052
auto *ctx = mod.getContext();
4153
auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
54+
auto idxTy = builder.getIndexType();
4255
auto funcTy =
4356
mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
57+
std::optional<mlir::DataLayout> dl =
58+
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false);
59+
if (!dl) {
60+
mlir::emitError(mod.getLoc(),
61+
"data layout attribute is required to perform " +
62+
getName() + "pass");
63+
}
4464

4565
// Symbol reference to CUFRegisterAllocator.
4666
builder.setInsertionPointToEnd(mod.getBody());
@@ -58,12 +78,13 @@ struct CUFAddConstructor
5878
builder.setInsertionPointToStart(func.addEntryBlock(builder));
5979
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
6080

61-
// Register kernels
6281
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
6382
if (gpuMod) {
6483
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
6584
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
6685
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
86+
87+
// Register kernels
6788
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
6889
if (func.isKernel()) {
6990
auto kernelName = mlir::SymbolRefAttr::get(
@@ -72,12 +93,55 @@ struct CUFAddConstructor
7293
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
7394
}
7495
}
96+
97+
// Register variables
98+
for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
99+
auto attr = globalOp.getDataAttrAttr();
100+
if (!attr)
101+
continue;
102+
103+
mlir::func::FuncOp func;
104+
switch (attr.getValue()) {
105+
case cuf::DataAttribute::Device:
106+
case cuf::DataAttribute::Constant: {
107+
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
108+
loc, builder);
109+
auto fTy = func.getFunctionType();
110+
111+
// Global variable name
112+
std::string gblNameStr = globalOp.getSymbol().getValue().str();
113+
gblNameStr += '\0';
114+
mlir::Value gblName = fir::getBase(
115+
fir::factory::createStringLiteral(builder, loc, gblNameStr));
116+
117+
// Global variable size
118+
auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(
119+
loc, globalOp.getType(), *dl, kindMap);
120+
auto size =
121+
builder.createIntegerConstant(loc, idxTy, sizeAndAlign.first);
122+
123+
// Global variable address
124+
mlir::Value addr = builder.create<fir::AddrOfOp>(
125+
loc, globalOp.resultType(), globalOp.getSymbol());
126+
127+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
128+
builder, loc, fTy, registeredMod, addr, gblName, size)};
129+
builder.create<fir::CallOp>(loc, func, args);
130+
} break;
131+
case cuf::DataAttribute::Managed:
132+
TODO(loc, "registration of managed variables");
133+
default:
134+
break;
135+
}
136+
if (!func)
137+
continue;
138+
}
75139
}
76140
builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
77141

78142
// Create the llvm.global_ctor with the function.
79-
// TODO: We might want to have a utility that retrieve it if already created
80-
// and adds new functions.
143+
// TODO: We might want to have a utility that retrieve it if already
144+
// created and adds new functions.
81145
builder.setInsertionPointToEnd(mod.getBody());
82146
llvm::SmallVector<mlir::Attribute> funcs;
83147
funcs.push_back(

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
111111
switch (attr.getValue()) {
112112
case cuf::DataAttribute::Device:
113113
case cuf::DataAttribute::Managed:
114-
case cuf::DataAttribute::Pinned:
114+
case cuf::DataAttribute::Constant:
115115
isDevGlobal = true;
116116
break;
117117
default:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: fir-opt --split-input-file --cuf-add-constructor %s | FileCheck %s
2+
3+
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (https://github.com/llvm/llvm-project.git cae351f3453a0a26ec8eb2ddaf773c24a29d929e)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
4+
5+
fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
6+
7+
gpu.module @cuda_device_mod [#nvvm.target] {
8+
}
9+
}
10+
11+
// CHECK: gpu.module @cuda_device_mod [#nvvm.target]
12+
13+
// CHECK: llvm.func internal @__cudaFortranConstructor() {
14+
// CHECK-DAG: %[[MODULE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
15+
// CHECK-DAG: %[[VAR_NAME:.*]] = fir.address_of(@_QQ{{.*}}) : !fir.ref<!fir.char<1,12>>
16+
// CHECK-DAG: %[[VAR_ADDR:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
17+
// CHECK-DAG: %[[MODULE2:.*]] = fir.convert %[[MODULE]] : (!llvm.ptr) -> !fir.ref<!fir.llvm_ptr<i8>>
18+
// CHECK-DAG: %[[VAR_ADDR2:.*]] = fir.convert %[[VAR_ADDR]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.ref<i8>
19+
// CHECK-DAG: %[[VAR_NAME2:.*]] = fir.convert %[[VAR_NAME]] : (!fir.ref<!fir.char<1,12>>) -> !fir.ref<i8>
20+
// CHECK-DAG: %[[CST:.*]] = arith.constant 20 : index
21+
// CHECK-DAG %[[CST2:.*]] = fir.convert %[[CST]] : (index) -> i64
22+
// CHECK fir.call @_FortranACUFRegisterVariable(%[[MODULE2]], %[[VAR_ADDR2]], %[[VAR_NAME2]], %[[CST2]]) : (!fir.ref<!fir.llvm_ptr<i8>>, !fir.ref<i8>, !fir.ref<i8>, i64) -> none

0 commit comments

Comments
 (0)