Skip to content

[flang][cuda] Convert cuf.sync_descriptor to runtime call #121524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def cuf_DeallocateOp : cuf_Op<"deallocate",
let hasVerifier = 1;
}

// Operation requesting that the host and device copies of the descriptor of a
// Fortran pointer be synchronized. The pointer is identified by the symbol of
// the global that holds its descriptor; the op has no operands and no results.
def cuf_SyncDescriptorOp : cuf_Op<"sync_descriptor", []> {
let summary =
"Synchronize the host and device descriptor of a Fortran pointer";

// Symbol reference naming the global whose descriptor must be synchronized.
let arguments = (ins SymbolRefAttr:$globalName);

// Textual form: the symbol reference followed by the attribute dictionary,
// e.g. `cuf.sync_descriptor @some_global`.
let assemblyFormat = [{
$globalName attr-dict
}];
}

def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
let summary = "Represent a data transfer between host and device memory";

Expand Down
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/CUDA/descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)(
void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
const char *sourceFile = nullptr, int sourceLine = 0);

/// Look up the device address registered for \p hostPtr and synchronize the
/// descriptor at that device address with the host descriptor located at
/// \p hostPtr. \p sourceFile and \p sourceLine identify the call site and are
/// forwarded to the underlying runtime calls.
void RTDECL(CUFSyncGlobalDescriptor)(
void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);

} // extern "C"

} // namespace Fortran::runtime::cuda
Expand Down
19 changes: 19 additions & 0 deletions flang/lib/Lower/Allocatable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/Runtime.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Parser/parse-tree.h"
Expand Down Expand Up @@ -1086,6 +1088,22 @@ bool Fortran::lower::isArraySectionWithoutVectorSubscript(
!Fortran::evaluate::HasVectorSubscript(expr);
}

/// Emit a cuf.sync_descriptor operation for \p box when it is produced by an
/// hlfir.declare whose memref is the address of a fir.global that
/// cuf::isRegisteredDeviceGlobal accepts. In every other case nothing is
/// generated.
static void genCUFPointerSync(const mlir::Value box,
                              fir::FirOpBuilder &builder) {
  // The box must come straight from an hlfir.declare.
  auto declare = box.getDefiningOp<hlfir::DeclareOp>();
  if (!declare)
    return;

  // The declared entity must be backed directly by a global address.
  auto globalAddr = declare.getMemref().getDefiningOp<fir::AddrOfOp>();
  if (!globalAddr)
    return;

  // Resolve the symbol to its fir.global in the enclosing module.
  auto parentModule = globalAddr->getParentOfType<mlir::ModuleOp>();
  auto global = parentModule.lookupSymbol<fir::GlobalOp>(globalAddr.getSymbol());
  if (!global)
    return;

  // Only globals registered on the device need their descriptor synced.
  if (!cuf::isRegisteredDeviceGlobal(global))
    return;

  builder.create<cuf::SyncDescriptorOp>(box.getLoc(), globalAddr.getSymbol());
}

void Fortran::lower::associateMutableBox(
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
const fir::MutableBoxValue &box, const Fortran::lower::SomeExpr &source,
Expand All @@ -1098,6 +1116,7 @@ void Fortran::lower::associateMutableBox(
if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
fir::ExtendedValue rhs = converter.genExprAddr(loc, source, stmtCtx);
fir::factory::associateMutableBox(builder, loc, box, rhs, lbounds);
genCUFPointerSync(box.getAddr(), builder);
return;
}
// The right hand side is not to be evaluated into a temp. Array sections can
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Builder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_flang_library(FIRBuilder
BoxValue.cpp
Character.cpp
Complex.cpp
CUFCommon.cpp
DoLoopHelper.cpp
FIRBuilder.cpp
HLFIRTools.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
Expand Down
1 change: 0 additions & 1 deletion flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ add_flang_library(FIRTransforms
CompilerGeneratedNames.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
CUFCommon.cpp
CUFAddConstructor.cpp
CUFDeviceGlobal.cpp
CUFOpConversion.cpp
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
Expand All @@ -19,7 +20,6 @@
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/registration.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
//===----------------------------------------------------------------------===//

#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
Expand Down
44 changes: 42 additions & 2 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

#include "flang/Optimizer/Transforms/CUFOpConversion.h"
#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/allocatable.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
Expand Down Expand Up @@ -788,6 +788,45 @@ struct CUFLaunchOpConversion
const mlir::SymbolTable &symTab;
};

/// Lower cuf.sync_descriptor to a call to the CUFSyncGlobalDescriptor runtime
/// entry point. The call receives the host address of the global named by the
/// operation, plus the source file name and line number derived from the
/// operation's location.
struct CUFSyncDescriptorOpConversion
    : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
  using OpRewritePattern::OpRewritePattern;

  /// \p symTab is accepted so this pattern can be registered with the same
  /// constructor arguments as the other symbol-table-aware CUF patterns.
  CUFSyncDescriptorOpConversion(mlir::MLIRContext *context,
                                const mlir::SymbolTable &symTab)
      : OpRewritePattern(context), symTab{symTab} {}

  /// Replace \p op with `fir.address_of` + `fir.call` to the runtime.
  /// Fails, leaving the IR untouched, when the referenced symbol does not
  /// resolve to a fir.global in the enclosing module.
  mlir::LogicalResult
  matchAndRewrite(cuf::SyncDescriptorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    // Resolve the global before creating any operation so a failed match
    // never leaves partially generated IR behind.
    auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
    if (!globalOp)
      return rewriter.notifyMatchFailure(
          op, "global referenced by cuf.sync_descriptor was not found");

    // Host-side address of the global holding the descriptor.
    auto hostAddr = builder.create<fir::AddrOfOp>(
        loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());

    mlir::func::FuncOp callee =
        fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
                                                                       builder);
    auto fTy = callee.getFunctionType();
    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
    // Input #2 of the runtime signature is the source line argument.
    mlir::Value sourceLine =
        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
        builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
    builder.create<fir::CallOp>(loc, callee, args);

    // Erase through the rewriter so the pattern driver is notified of the
    // mutation; erasing the op directly bypasses the driver's bookkeeping.
    rewriter.eraseOp(op);
    return mlir::success();
  }

private:
  const mlir::SymbolTable &symTab;
};

class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
Expand Down Expand Up @@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns(
CUFFreeOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>(
patterns.getContext(), symtab);
}

void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "flang/Runtime/entry-names.h"
Expand Down
7 changes: 7 additions & 0 deletions flang/runtime/CUDA/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
(void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
}

// Resolves the device address registered for hostPtr, then syncs the
// descriptor at that address from the host descriptor at hostPtr.
void RTDEF(CUFSyncGlobalDescriptor)(
    void *hostPtr, const char *sourceFile, int sourceLine) {
  auto *deviceDesc{static_cast<Descriptor *>(
      RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine))};
  auto *hostDesc{static_cast<Descriptor *>(hostPtr)};
  RTNAME(CUFDescriptorSync)(deviceDesc, hostDesc, sourceFile, sourceLine);
}

RT_EXT_API_GROUP_END
}
} // namespace Fortran::runtime::cuda
20 changes: 20 additions & 0 deletions flang/test/Fir/CUDA/cuda-sync-desc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: fir-opt --cuf-convert %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 ([email protected]:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// Global carrying the CUDA `device` data attribute. It holds the descriptor
// (box) of a rank-1 f32 Fortran pointer; the initializer emboxes a zero_bits
// address with a zero-extent shape.
fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
%0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
%c0 = arith.constant 0 : index
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
}
// Operation under test: the --cuf-convert pass is expected to rewrite the
// cuf.sync_descriptor below into a fir.address_of of the global followed by a
// call to _FortranACUFSyncGlobalDescriptor (see the FileCheck directives).
func.func @_QQmain() {
cuf.sync_descriptor @_QMdevptrEdev_ptr
return
}
}

// CHECK-LABEL: func.func @_QQmain()
// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32)
17 changes: 17 additions & 0 deletions flang/test/Lower/CUDA/cuda-pointer-sync.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Module providing a rank-1 real pointer with the CUDA Fortran DEVICE
! attribute. Lowering is expected to emit it as the fir.global
! @_QMdevptrEdev_ptr carrying the cuf.cuda<device> data attribute.
module devptr
real, device, pointer, dimension(:) :: dev_ptr
end module

! Main program: pointer-associates the module-level device pointer with a
! local device target array. The FileCheck directives below expect the
! association to be followed by a cuf.sync_descriptor on the pointer's global.
use devptr
real, device, target, dimension(4) :: a_dev
a_dev = 42.0
dev_ptr => a_dev
end

! CHECK: fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>>
! CHECK-LABEL: func.func @_QQmain()
! CHECK: fir.embox
! CHECK: fir.store
! CHECK: cuf.sync_descriptor @_QMdevptrEdev_ptr
Loading