Skip to content

Commit b4d3c2c

Browse files
authored
[flang][cuda] Update FIROps.td to add $grid_z to CudaKernelLaunch (#85318)
grid can be 3 dimensional. (@clementval)
1 parent 2ad9706 commit b4d3c2c

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,6 +2454,7 @@ def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
24542454
SymbolRefAttr:$callee,
24552455
I32:$grid_x,
24562456
I32:$grid_y,
2457+
I32:$grid_z,
24572458
I32:$block_x,
24582459
I32:$block_y,
24592460
I32:$block_z,
@@ -2463,8 +2464,8 @@ def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
24632464
);
24642465

24652466
let assemblyFormat = [{
2466-
$callee `<` `<` `<` $grid_x `,` $grid_y `,` $block_x `,` $block_y `,`
2467-
$block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
2467+
$callee `<` `<` `<` $grid_x `,` $grid_y `,` $grid_z `,`$block_x `,`
2468+
$block_y `,` $block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
24682469
`` `(` ( $args^ `:` type($args) )? `)` attr-dict
24692470
}];
24702471

flang/lib/Lower/ConvertCall.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,8 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
466466
caller.getCallDescription().chevrons()[3], stmtCtx)));
467467

468468
builder.create<fir::CUDAKernelLaunch>(
469-
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, block_x,
470-
block_y, block_z, bytes, stream, operands);
469+
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, one,
470+
block_x, block_y, block_z, bytes, stream, operands);
471471
callNumResults = 0;
472472
} else if (caller.requireDispatchCall()) {
473473
// Procedure call requiring a dynamic dispatch. Call is created with

flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ contains
1818
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
1919

2020
call dev_kernel0<<<10, 20>>>()
21-
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
21+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
2222

2323
call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
2424
! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
@@ -35,16 +35,16 @@ contains
3535
! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
3636
! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
3737
! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
38-
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
38+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %c1{{.*}}, %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
3939

4040
call dev_kernel0<<<10, 20, 2>>>()
41-
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()
41+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()
4242

4343
call dev_kernel0<<<10, 20, 2, 0>>>()
44-
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
44+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
4545

4646
call dev_kernel1<<<1, 32>>>(a)
47-
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
47+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
4848
end
4949

5050
end

0 commit comments

Comments
 (0)