Skip to content

[flang][cuda] Adding variable registration in constructor #113976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/Target.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/registration.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"

Expand All @@ -23,6 +31,8 @@ namespace fir {
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace Fortran::runtime::cuda;

namespace {

static constexpr llvm::StringRef cudaFortranCtorName{
Expand All @@ -34,13 +44,23 @@ struct CUFAddConstructor
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
mlir::SymbolTable symTab(mod);
mlir::OpBuilder builder{mod.getBodyRegion()};
mlir::OpBuilder opBuilder{mod.getBodyRegion()};
fir::FirOpBuilder builder(opBuilder, mod);
fir::KindMapping kindMap{fir::getKindMapping(mod)};
builder.setInsertionPointToEnd(mod.getBody());
mlir::Location loc = mod.getLoc();
auto *ctx = mod.getContext();
auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
auto idxTy = builder.getIndexType();
auto funcTy =
mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false);
if (!dl) {
mlir::emitError(mod.getLoc(),
"data layout attribute is required to perform " +
getName() + "pass");
}

// Symbol reference to CUFRegisterAllocator.
builder.setInsertionPointToEnd(mod.getBody());
Expand All @@ -58,12 +78,13 @@ struct CUFAddConstructor
builder.setInsertionPointToStart(func.addEntryBlock(builder));
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);

// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
if (gpuMod) {
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));

// Register kernels
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
Expand All @@ -72,12 +93,55 @@ struct CUFAddConstructor
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
}
}

// Register variables
for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
auto attr = globalOp.getDataAttrAttr();
if (!attr)
continue;

mlir::func::FuncOp func;
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Constant: {
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
loc, builder);
auto fTy = func.getFunctionType();

// Global variable name
std::string gblNameStr = globalOp.getSymbol().getValue().str();
gblNameStr += '\0';
mlir::Value gblName = fir::getBase(
fir::factory::createStringLiteral(builder, loc, gblNameStr));

// Global variable size
auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(
loc, globalOp.getType(), *dl, kindMap);
auto size =
builder.createIntegerConstant(loc, idxTy, sizeAndAlign.first);

// Global variable address
mlir::Value addr = builder.create<fir::AddrOfOp>(
loc, globalOp.resultType(), globalOp.getSymbol());

llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, registeredMod, addr, gblName, size)};
builder.create<fir::CallOp>(loc, func, args);
} break;
case cuf::DataAttribute::Managed:
TODO(loc, "registration of managed variables");
default:
break;
}
if (!func)
continue;
}
}
builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});

// Create the llvm.global_ctor with the function.
// TODO: We might want to have a utility that retrieve it if already created
// and adds new functions.
// TODO: We might want to have a utility that retrieve it if already
// created and adds new functions.
builder.setInsertionPointToEnd(mod.getBody());
llvm::SmallVector<mlir::Attribute> funcs;
funcs.push_back(
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Managed:
case cuf::DataAttribute::Pinned:
case cuf::DataAttribute::Constant:
isDevGlobal = true;
break;
default:
Expand Down
22 changes: 22 additions & 0 deletions flang/test/Fir/CUDA/cuda-constructor-2.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: fir-opt --split-input-file --cuf-add-constructor %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (https://github.com/llvm/llvm-project.git cae351f3453a0a26ec8eb2ddaf773c24a29d929e)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {

fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>

gpu.module @cuda_device_mod [#nvvm.target] {
}
}

// CHECK: gpu.module @cuda_device_mod [#nvvm.target]

// CHECK: llvm.func internal @__cudaFortranConstructor() {
// CHECK-DAG: %[[MODULE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
// CHECK-DAG: %[[VAR_NAME:.*]] = fir.address_of(@_QQ{{.*}}) : !fir.ref<!fir.char<1,12>>
// CHECK-DAG: %[[VAR_ADDR:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK-DAG: %[[MODULE2:.*]] = fir.convert %[[MODULE]] : (!llvm.ptr) -> !fir.ref<!fir.llvm_ptr<i8>>
// CHECK-DAG: %[[VAR_ADDR2:.*]] = fir.convert %[[VAR_ADDR]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.ref<i8>
// CHECK-DAG: %[[VAR_NAME2:.*]] = fir.convert %[[VAR_NAME]] : (!fir.ref<!fir.char<1,12>>) -> !fir.ref<i8>
// CHECK-DAG: %[[CST:.*]] = arith.constant 20 : index
// CHECK-DAG %[[CST2:.*]] = fir.convert %[[CST]] : (index) -> i64
// CHECK fir.call @_FortranACUFRegisterVariable(%[[MODULE2]], %[[VAR_ADDR2]], %[[VAR_NAME2]], %[[CST2]]) : (!fir.ref<!fir.llvm_ptr<i8>>, !fir.ref<i8>, !fir.ref<i8>, i64) -> none
Loading