Skip to content

Commit ba4dc5a

Browse files
authored
[flang][cuda] Pass the device address for global descriptor (#122802)
1 parent 5ea1c87 commit ba4dc5a

File tree

2 files changed

+86
-21
lines changed

2 files changed

+86
-21
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 44 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366366
const fir::LLVMTypeConverter *typeConverter;
367367
};
368368

369+
/// Emit a call to the `CUFGetDeviceAddress` runtime entry point to obtain the
/// device-side address corresponding to `inputArg` (the host address of a
/// registered device global). The runtime call traffics in opaque pointers, so
/// the input is converted on the way in and the result is converted back to
/// `inputArg`'s type so it can transparently replace it.
static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
                                       mlir::ModuleOp mod, mlir::Location loc,
                                       mlir::Value inputArg) {
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::func::FuncOp callee =
      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
  auto funcTy = callee.getFunctionType();
  // Runtime expects an opaque pointer as its first argument.
  mlir::Value hostAddr =
      createConvertOp(rewriter, loc, funcTy.getInput(0), inputArg);
  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
  mlir::Value sourceLine =
      fir::factory::locationToLineNo(builder, loc, funcTy.getInput(2));
  llvm::SmallVector<mlir::Value> callArgs{fir::runtime::createArguments(
      builder, loc, funcTy, hostAddr, sourceFile, sourceLine)};
  auto callOp = rewriter.create<fir::CallOp>(loc, callee, callArgs);
  // Cast the opaque device pointer back to the caller's original type.
  return createConvertOp(rewriter, loc, inputArg.getType(),
                         callOp->getResult(0));
}
385+
369386
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
370387
using OpRewritePattern::OpRewritePattern;
371388

@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
382399
if (cuf::isRegisteredDeviceGlobal(global)) {
383400
rewriter.setInsertionPointAfter(addrOfOp);
384401
auto mod = op->getParentOfType<mlir::ModuleOp>();
385-
fir::FirOpBuilder builder(rewriter, mod);
386-
mlir::Location loc = op.getLoc();
387-
mlir::func::FuncOp callee =
388-
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
389-
loc, builder);
390-
auto fTy = callee.getFunctionType();
391-
mlir::Type toTy = fTy.getInput(0);
392-
mlir::Value inputArg =
393-
createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
394-
mlir::Value sourceFile =
395-
fir::factory::locationToFilename(builder, loc);
396-
mlir::Value sourceLine =
397-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
398-
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
399-
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
400-
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
401-
mlir::Value cast = createConvertOp(
402-
rewriter, loc, op.getMemref().getType(), call->getResult(0));
402+
mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
403+
addrOfOp.getResult());
403404
rewriter.startOpModification(op);
404-
op.getMemrefMutable().assign(cast);
405+
op.getMemrefMutable().assign(devAddr);
405406
rewriter.finalizeOpModification(op);
406407
return success();
407408
}
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
771772
loc, clusterDimsAttr.getZ().getInt());
772773
}
773774
}
775+
llvm::SmallVector<mlir::Value> args;
776+
auto mod = op->getParentOfType<mlir::ModuleOp>();
777+
for (mlir::Value arg : op.getArgs()) {
778+
// If the argument is a global descriptor, make sure we pass the device
779+
// copy of this descriptor and not the host one.
780+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
781+
if (auto declareOp =
782+
mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
783+
if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
784+
declareOp.getMemref().getDefiningOp())) {
785+
if (auto global = symTab.lookup<fir::GlobalOp>(
786+
addrOfOp.getSymbol().getRootReference().getValue())) {
787+
if (cuf::isRegisteredDeviceGlobal(global)) {
788+
arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
789+
declareOp.getResult());
790+
}
791+
}
792+
}
793+
}
794+
}
795+
args.push_back(arg);
796+
}
797+
774798
auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
775799
loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
776-
mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
777-
op.getArgs());
800+
mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
778801
if (clusterDimX && clusterDimY && clusterDimZ) {
779802
gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
780803
gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);

flang/test/Fir/CUDA/cuda-launch.fir

Lines changed: 42 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -62,3 +62,45 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
6262
// CHECK-LABEL: func.func @_QMmod1Phost_sub()
6363
// CHECK: gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 clusters in (%c2{{.*}}, %c2{{.*}}, %c1{{.*}})
6464

65+
// -----
66+
67+
module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
68+
gpu.module @cuda_device_mod {
69+
gpu.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) kernel {
70+
gpu.return
71+
}
72+
}
73+
fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
74+
%c0 = arith.constant 0 : index
75+
%0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
76+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
77+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
78+
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
79+
}
80+
func.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "dp"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
81+
return
82+
}
83+
func.func @_QQmain() {
84+
%c1_i32 = arith.constant 1 : i32
85+
%c4 = arith.constant 4 : index
86+
%0 = cuf.alloc !fir.array<4xf32> {bindc_name = "a_dev", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa_dev"} -> !fir.ref<!fir.array<4xf32>>
87+
%1 = fir.shape %c4 : (index) -> !fir.shape<1>
88+
%2 = fir.declare %0(%1) {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFEa_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
89+
%3 = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
90+
%4 = fir.declare %3 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
91+
%5 = fir.embox %2(%1) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
92+
fir.store %5 to %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
93+
cuf.sync_descriptor @_QMdevptrEdev_ptr
94+
cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>(%4) : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
95+
cuf.free %2 : !fir.ref<!fir.array<4xf32>> {data_attr = #cuf.cuda<device>}
96+
return
97+
}
98+
}
99+
100+
// CHECK-LABEL: func.func @_QQmain()
101+
// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
102+
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
103+
// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
104+
// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
105+
// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
106+
// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)

0 commit comments

Comments
 (0)