Skip to content

[flang][cuda] Allocate descriptor in managed memory when emboxing device memory #120485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 116 additions & 95 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocator-registry-consts.h"
#include "flang/Runtime/descriptor-consts.h"
#include "flang/Semantics/runtime-type-info.h"
Expand Down Expand Up @@ -1135,6 +1136,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
return result;
}

static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
mlir::ConversionPatternRewriter &rewriter) {
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
auto fn = flc.getFilename().str() + '\0';
std::string globalName = fir::factory::uniqueCGIdent("cl", fn);

if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
} else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
}

auto crtInsPt = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
auto arrayTy = mlir::LLVM::LLVMArrayType::get(
mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
globalName, mlir::Attribute());

mlir::Region &region = globalOp.getInitializerRegion();
mlir::Block *block = rewriter.createBlock(&region);
rewriter.setInsertionPoint(block, block->begin());
mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, arrayTy, rewriter.getStringAttr(fn));
rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
rewriter.restoreInsertionPoint(crtInsPt);
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
globalOp.getName());
}
return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
}

static mlir::Value genSourceLine(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter) {
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
flc.getLine());
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
}

static mlir::Value
genCUFAllocDescriptor(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::ModuleOp mod, fir::BaseBoxType boxTy,
const fir::LLVMTypeConverter &typeConverter) {
std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
if (!dl)
mlir::emitError(mod.getLoc(),
"module operation must carry a data layout attribute "
"to generate llvm IR from FIR");

mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
mlir::Value sourceLine = genSourceLine(loc, rewriter);

mlir::MLIRContext *ctx = mod.getContext();

mlir::LLVM::LLVMPointerType llvmPointerType =
mlir::LLVM::LLVMPointerType::get(ctx);
mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
mlir::Type llvmIntPtrType =
mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
auto fctTy = mlir::LLVM::LLVMFunctionType::get(
llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});

auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
RTNAME_STRING(CUFAllocDesciptor));
auto funcFunc =
mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
if (!llvmFunc && !funcFunc)
mlir::OpBuilder::atBlockEnd(mod.getBody())
.create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
fctTy);

mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
mlir::Value sizeInBytes =
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
return rewriter
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
args)
.getResult();
}

/// Common base class for embox to descriptor conversion.
template <typename OP>
struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
Expand Down Expand Up @@ -1548,15 +1636,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
mlir::Value
placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Type boxTy,
mlir::Value boxValue) const {
mlir::Value boxValue,
bool needDeviceAllocation = false) const {
if (isInGlobalOp(rewriter))
return boxValue;
mlir::Type llvmBoxTy = boxValue.getType();
auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy,
defaultAlign, rewriter);
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
mlir::Value storage;
if (needDeviceAllocation) {
auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>();
auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy);
storage =
genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy());
} else {
storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign,
rewriter);
}
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage);
this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
return alloca;
return storage;
}
};

Expand Down Expand Up @@ -1608,6 +1705,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
}
};

static bool isDeviceAllocation(mlir::Value val) {
if (auto convertOp =
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
val = convertOp.getValue();
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
if (callOp.getCallee() &&
callOp.getCallee().value().getRootReference().getValue().starts_with(
RTNAME_STRING(CUFMemAlloc)))
return true;
return false;
}

/// Create a generic box on a memory reference.
struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
using EmboxCommonConversion::EmboxCommonConversion;
Expand Down Expand Up @@ -1791,9 +1900,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
dest = insertBaseAddress(rewriter, loc, dest, base);
if (fir::isDerivedTypeWithLenParams(boxTy))
TODO(loc, "fir.embox codegen of derived with length parameters");

mlir::Value result =
placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
mlir::Value result = placeInMemoryIfNotGlobalInit(
rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
rewriter.replaceOp(xbox, result);
return mlir::success();
}
Expand Down Expand Up @@ -2971,93 +3079,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
}
};

static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
mlir::ConversionPatternRewriter &rewriter) {
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
auto fn = flc.getFilename().str() + '\0';
std::string globalName = fir::factory::uniqueCGIdent("cl", fn);

if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
} else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
}

auto crtInsPt = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
auto arrayTy = mlir::LLVM::LLVMArrayType::get(
mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
globalName, mlir::Attribute());

mlir::Region &region = globalOp.getInitializerRegion();
mlir::Block *block = rewriter.createBlock(&region);
rewriter.setInsertionPoint(block, block->begin());
mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, arrayTy, rewriter.getStringAttr(fn));
rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
rewriter.restoreInsertionPoint(crtInsPt);
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
globalOp.getName());
}
return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
}

static mlir::Value genSourceLine(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter) {
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
flc.getLine());
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
}

static mlir::Value
genCUFAllocDescriptor(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::ModuleOp mod, fir::BaseBoxType boxTy,
const fir::LLVMTypeConverter &typeConverter) {
std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
if (!dl)
mlir::emitError(mod.getLoc(),
"module operation must carry a data layout attribute "
"to generate llvm IR from FIR");

mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
mlir::Value sourceLine = genSourceLine(loc, rewriter);

mlir::MLIRContext *ctx = mod.getContext();

mlir::LLVM::LLVMPointerType llvmPointerType =
mlir::LLVM::LLVMPointerType::get(ctx);
mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
mlir::Type llvmIntPtrType =
mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
auto fctTy = mlir::LLVM::LLVMFunctionType::get(
llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});

auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
RTNAME_STRING(CUFAllocDesciptor));
auto funcFunc =
mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
if (!llvmFunc && !funcFunc)
mlir::OpBuilder::atBlockEnd(mod.getBody())
.create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
fctTy);

mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
mlir::Value sizeInBytes =
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
return rewriter
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
args)
.getResult();
}

/// `fir.load` --> `llvm.load`
struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
using FIROpConversion::FIROpConversion;
Expand Down
31 changes: 30 additions & 1 deletion flang/test/Fir/CUDA/cuda-code-gen.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {

func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
%c0 = arith.constant 0 : index
%0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
Expand All @@ -27,3 +26,33 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
}
func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
}

// -----

module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
func.func @_QQmain() attributes {fir.bindc_name = "test"} {
%c10 = arith.constant 10 : index
%c20 = arith.constant 20 : index
%0 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
%c4 = arith.constant 4 : index
%c200 = arith.constant 200 : index
%1 = arith.muli %c200, %c4 : index
%c6_i32 = arith.constant 6 : i32
%c0_i32 = arith.constant 0 : i32
%2 = fir.convert %1 : (index) -> i64
%3 = fir.convert %0 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
%4 = fir.call @_FortranACUFMemAlloc(%2, %c0_i32, %3, %c6_i32) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
%5 = fir.convert %4 : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10x20xi32>>
%6 = fircg.ext_embox %5(%c10, %c20) : (!fir.ref<!fir.array<10x20xi32>>, index, index) -> !fir.box<!fir.array<10x20xi32>>
return
}
fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
%0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
fir.has_value %0 : !fir.char<1,11>
}
func.func private @_FortranACUFMemAlloc(i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8> attributes {fir.runtime}
}

// CHECK-LABEL: llvm.func @_QQmain()
// CHECK: llvm.call @_FortranACUFMemAlloc
// CHECK: llvm.call @_FortranACUFAllocDesciptor
Loading