
Commit 5be9082

[flang][cuda] Carry over the dynamic shared memory size to gpu.launch_func (#132837)
1 parent 4c68061
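
In short: the `cuf.kernel_launch` op carries an optional `bytes` operand, the dynamic shared-memory size from the kernel-launch chevrons, but the conversion to `gpu.launch_func` previously passed a hard-coded zero and dropped the requested size. A minimal before/after sketch of the lowering, reusing the kernel and constants from the updated test below (SSA names illustrative):

// Input: a launch requesting 1024 bytes of dynamic shared memory.
%c1024_i32 = arith.constant 1024 : i32
cuf.kernel_launch @cuda_device_mod::@_QPsub_device1<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1024_i32>>>()

// Output after this change: the size is carried over instead of being zeroed.
gpu.launch_func @cuda_device_mod::@_QPsub_device1 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) dynamic_shared_memory_size %c1024_i32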

File tree

2 files changed: +8 -3 lines

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 4 additions & 3 deletions
@@ -810,7 +810,7 @@ struct CUFLaunchOpConversion
                     mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = op.getLoc();
     auto idxTy = mlir::IndexType::get(op.getContext());
-    auto zero = rewriter.create<mlir::arith::ConstantOp>(
+    mlir::Value zero = rewriter.create<mlir::arith::ConstantOp>(
         loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
     auto gridSizeX =
         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
@@ -869,10 +869,11 @@ struct CUFLaunchOpConversion
       }
       args.push_back(arg);
     }
-
+    mlir::Value dynamicShmemSize = op.getBytes() ? op.getBytes() : zero;
     auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
         loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
-        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
+        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ},
+        dynamicShmemSize, args);
     if (clusterDimX && clusterDimY && clusterDimZ) {
       gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
       gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
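
Two details worth noting here. First, `op.getBytes()` is an optional-operand accessor: it returns a null `mlir::Value` when the launch carries no shared-memory size, so the ternary falls back to the materialized i32 zero. Second, `zero` is now declared `mlir::Value` rather than `auto` (which would deduce `mlir::arith::ConstantOp`), so both arms of the ternary have the same type. For a launch without a `bytes` operand the emitted IR is unchanged; a sketch (SSA names illustrative):

// No size requested: the conversion still passes an explicit i32 zero.
%c0_i32 = arith.constant 0 : i32
gpu.launch_func @cuda_device_mod::@_QPsub_device1 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) dynamic_shared_memory_size %c0_i32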

flang/test/Fir/CUDA/cuda-launch.fir

Lines changed: 4 additions & 0 deletions
@@ -23,11 +23,15 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
   // CHECK: %[[ALLOCA:.*]] = fir.alloca f32
   %c1 = arith.constant 1 : index
   %c11_i32 = arith.constant 11 : i32
+  %c1024_i32 = arith.constant 1024 : i32
   %c6_i32 = arith.constant 6 : i32
   %c1_i32 = arith.constant 1 : i32
   // CHECK: gpu.launch_func @cuda_device_mod::@_QPsub_device1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %c0{{.*}}
   cuf.kernel_launch @cuda_device_mod::@_QPsub_device1<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>()

+  // CHECK: gpu.launch_func @cuda_device_mod::@_QPsub_device1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %c1024{{.*}}
+  cuf.kernel_launch @cuda_device_mod::@_QPsub_device1<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1024_i32>>>()
+
   // CHECK: gpu.launch_func @cuda_device_mod::@_QPsub_device2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %c0{{.*}} args(%[[ALLOCA]] : !fir.ref<f32>)
   cuf.kernel_launch @cuda_device_mod::@_QPsub_device2<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>(%0) : (!fir.ref<f32>)
   return
