
[flang][cuda] Pass the device address for global descriptor #122802


Merged: 1 commit merged into llvm:main on Jan 14, 2025

Conversation

clementval
Contributor

Module variables requiring a descriptor are implemented with two descriptors: one residing on the host and one on the device.
When passing a global descriptor to a kernel launch, the address of the device descriptor must be substituted so that the kernel accesses the descriptor on the device.
This patch inserts calls to CUFGetDeviceAddress during the conversion of the cuf.kernel_launch operation so that the arguments are correct.
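
For context, the snippet below is a minimal CUDA C++ sketch (not part of this patch; the names dev_data and fill are illustrative only) of the same idea: a device-resident global has its own device address, which host code must query before a kernel launch, much like the CUFGetDeviceAddress calls inserted by this conversion.

#include <cstdio>
#include <cuda_runtime.h>

__device__ float dev_data[4];        // device-resident global, analogous to the device-side descriptor

__global__ void fill(float *p) {     // the kernel must be handed a device address
  p[threadIdx.x] = static_cast<float>(threadIdx.x);
}

int main() {
  float *devAddr = nullptr;
  // Host code cannot take the address of dev_data directly; it queries the
  // device address first, just as the rewritten cuf.kernel_launch substitutes
  // the device descriptor address for the host one.
  cudaGetSymbolAddress(reinterpret_cast<void **>(&devAddr), dev_data);
  fill<<<1, 4>>>(devAddr);

  float host[4];
  cudaMemcpy(host, devAddr, sizeof(host), cudaMemcpyDeviceToHost);
  std::printf("%f %f %f %f\n", host[0], host[1], host[2], host[3]);
  return 0;
}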

llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels on Jan 13, 2025
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Full diff: https://github.com/llvm/llvm-project/pull/122802.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+44-21)
  • (modified) flang/test/Fir/CUDA/cuda-launch.fir (+42)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8c525fc6daff5e..d61d9f63cb2949 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   const fir::LLVMTypeConverter *typeConverter;
 };
 
+static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
+                                       mlir::ModuleOp mod, mlir::Location loc,
+                                       mlir::Value inputArg) {
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::func::FuncOp callee =
+      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
+  auto fTy = callee.getFunctionType();
+  mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
+  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+  mlir::Value sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+      builder, loc, fTy, conv, sourceFile, sourceLine)};
+  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+  return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
+}
+
 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
         if (cuf::isRegisteredDeviceGlobal(global)) {
           rewriter.setInsertionPointAfter(addrOfOp);
           auto mod = op->getParentOfType<mlir::ModuleOp>();
-          fir::FirOpBuilder builder(rewriter, mod);
-          mlir::Location loc = op.getLoc();
-          mlir::func::FuncOp callee =
-              fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
-                  loc, builder);
-          auto fTy = callee.getFunctionType();
-          mlir::Type toTy = fTy.getInput(0);
-          mlir::Value inputArg =
-              createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
-          mlir::Value sourceFile =
-              fir::factory::locationToFilename(builder, loc);
-          mlir::Value sourceLine =
-              fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-          llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-              builder, loc, fTy, inputArg, sourceFile, sourceLine)};
-          auto call = rewriter.create<fir::CallOp>(loc, callee, args);
-          mlir::Value cast = createConvertOp(
-              rewriter, loc, op.getMemref().getType(), call->getResult(0));
+          mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+                                                    addrOfOp.getResult());
           rewriter.startOpModification(op);
-          op.getMemrefMutable().assign(cast);
+          op.getMemrefMutable().assign(devAddr);
           rewriter.finalizeOpModification(op);
           return success();
         }
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
             loc, clusterDimsAttr.getZ().getInt());
       }
     }
+    llvm::SmallVector<mlir::Value> args;
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    for (mlir::Value arg : op.getArgs()) {
+      // If the argument is a global descriptor, make sure we pass the device
+      // copy of this descriptor and not the host one.
+      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
+        if (auto declareOp =
+                mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
+          if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
+                  declareOp.getMemref().getDefiningOp())) {
+            if (auto global = symTab.lookup<fir::GlobalOp>(
+                    addrOfOp.getSymbol().getRootReference().getValue())) {
+              if (cuf::isRegisteredDeviceGlobal(global)) {
+                arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+                                          declareOp.getResult());
+              }
+            }
+          }
+        }
+      }
+      args.push_back(arg);
+    }
+
     auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
         loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
-        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
-        op.getArgs());
+        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
     if (clusterDimX && clusterDimY && clusterDimZ) {
       gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
       gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index f11bcbdb7fce55..1e19b3bea1296f 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -62,3 +62,45 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
 // CHECK-LABEL: func.func @_QMmod1Phost_sub()
 // CHECK: gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 clusters in (%c2{{.*}}, %c2{{.*}}, %c1{{.*}})
 
+// -----
+
+module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+  gpu.module @cuda_device_mod {
+    gpu.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) kernel {
+      gpu.return
+    }
+  }
+  fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+    %c0 = arith.constant 0 : index
+    %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+    %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+    %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+    fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  }
+  func.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "dp"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+    return
+  }
+  func.func @_QQmain() {
+    %c1_i32 = arith.constant 1 : i32
+    %c4 = arith.constant 4 : index
+    %0 = cuf.alloc !fir.array<4xf32> {bindc_name = "a_dev", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa_dev"} -> !fir.ref<!fir.array<4xf32>>
+    %1 = fir.shape %c4 : (index) -> !fir.shape<1>
+    %2 = fir.declare %0(%1) {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFEa_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+    %3 = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+    %4 = fir.declare %3 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+    %5 = fir.embox %2(%1) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+    fir.store %5 to %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+    cuf.sync_descriptor @_QMdevptrEdev_ptr
+    cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>(%4) : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+    cuf.free %2 : !fir.ref<!fir.array<4xf32>> {data_attr = #cuf.cuda<device>}
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: gpu.launch_func  @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})  dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)

Contributor

@Renaud-K Renaud-K left a comment


LGTM.

@clementval clementval merged commit ba4dc5a into llvm:main Jan 14, 2025
9 of 11 checks passed
@clementval clementval deleted the cuf_global_dev_addr branch January 14, 2025 01:23
clementval added a commit that referenced this pull request Jan 15, 2025
Introduce a new op to get the device address from a host symbol. This
simplifies the current conversion and is also in preparation for some
legalization work that needs to be done in cuf kernel and cuf kernel
launch, similar to
#122802
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 15, 2025
Introduce a new op to get the device address from a host symbol. This
simplifies the current conversion and is also in preparation for some
legalization work that needs to be done in cuf kernel and cuf kernel
launch, similar to
llvm/llvm-project#122802
Labels: flang, flang:fir-hlfir
3 participants