-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Pass the device address for global descriptor #122802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Module variables requiring a descriptor are implemented with two descriptors: one residing on the host and one on the device. Full diff: https://github.com/llvm/llvm-project/pull/122802.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8c525fc6daff5e..d61d9f63cb2949 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};
+static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
+ mlir::ModuleOp mod, mlir::Location loc,
+ mlir::Value inputArg) {
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, conv, sourceFile, sourceLine)};
+ auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+ return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
+}
+
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
if (cuf::isRegisteredDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
- loc, builder);
- auto fTy = callee.getFunctionType();
- mlir::Type toTy = fTy.getInput(0);
- mlir::Value inputArg =
- createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
- mlir::Value sourceFile =
- fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, inputArg, sourceFile, sourceLine)};
- auto call = rewriter.create<fir::CallOp>(loc, callee, args);
- mlir::Value cast = createConvertOp(
- rewriter, loc, op.getMemref().getType(), call->getResult(0));
+ mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+ addrOfOp.getResult());
rewriter.startOpModification(op);
- op.getMemrefMutable().assign(cast);
+ op.getMemrefMutable().assign(devAddr);
rewriter.finalizeOpModification(op);
return success();
}
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
loc, clusterDimsAttr.getZ().getInt());
}
}
+ llvm::SmallVector<mlir::Value> args;
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ for (mlir::Value arg : op.getArgs()) {
+ // If the argument is a global descriptor, make sure we pass the device
+ // copy of this descriptor and not the host one.
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
+ if (auto declareOp =
+ mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
+ if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
+ declareOp.getMemref().getDefiningOp())) {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue())) {
+ if (cuf::isRegisteredDeviceGlobal(global)) {
+ arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+ declareOp.getResult());
+ }
+ }
+ }
+ }
+ }
+ args.push_back(arg);
+ }
+
auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
- mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
- op.getArgs());
+ mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
if (clusterDimX && clusterDimY && clusterDimZ) {
gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index f11bcbdb7fce55..1e19b3bea1296f 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -62,3 +62,45 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
// CHECK-LABEL: func.func @_QMmod1Phost_sub()
// CHECK: gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 clusters in (%c2{{.*}}, %c2{{.*}}, %c1{{.*}})
+// -----
+
+module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+ gpu.module @cuda_device_mod {
+ gpu.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) kernel {
+ gpu.return
+ }
+ }
+ fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+ %c0 = arith.constant 0 : index
+ %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ }
+ func.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "dp"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+ return
+ }
+ func.func @_QQmain() {
+ %c1_i32 = arith.constant 1 : i32
+ %c4 = arith.constant 4 : index
+ %0 = cuf.alloc !fir.array<4xf32> {bindc_name = "a_dev", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa_dev"} -> !fir.ref<!fir.array<4xf32>>
+ %1 = fir.shape %c4 : (index) -> !fir.shape<1>
+ %2 = fir.declare %0(%1) {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFEa_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+ %3 = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %4 = fir.declare %3 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %5 = fir.embox %2(%1) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.store %5 to %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ cuf.sync_descriptor @_QMdevptrEdev_ptr
+ cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>(%4) : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+ cuf.free %2 : !fir.ref<!fir.array<4xf32>> {data_attr = #cuf.cuda<device>}
+ return
+ }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to #122802
Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to llvm/llvm-project#122802
Module variables requiring a descriptor are implemented with two descriptors: one residing on the host and one on the device.
When passing a global descriptor to a kernel launch, the address of the device descriptor must be substituted so the kernel will access the descriptor on the device.
This patch inserts calls to `CUFGetDeviceAddress` during the conversion of the `cuf.kernel_launch` operation so the arguments are correct.