Skip to content

Commit 69ccb13

Browse files
authored
[flang][cuda] Make argument passed by value for sync functions (#125909)
`syncthreads_and`, `syncthreads_count`, `syncthreads_or`, and `syncwarp` must take their arguments by value. This patch updates the interfaces and makes sure these functions can be called inside a CUF kernel as well.
1 parent 718b16a commit 69ccb13

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

flang/module/cudadevice.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,28 @@ attributes(device) subroutine syncthreads()
2929

3030
interface
3131
attributes(device) integer function syncthreads_and(value)
32-
integer :: value
32+
integer, value :: value
3333
end function
3434
end interface
3535
public :: syncthreads_and
3636

3737
interface
3838
attributes(device) integer function syncthreads_count(value)
39-
integer :: value
39+
integer, value :: value
4040
end function
4141
end interface
4242
public :: syncthreads_count
4343

4444
interface
4545
attributes(device) integer function syncthreads_or(value)
46-
integer :: value
46+
integer, value :: value
4747
end function
4848
end interface
4949
public :: syncthreads_or
5050

5151
interface
5252
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
53-
integer :: mask
53+
integer, value :: mask
5454
end subroutine
5555
end interface
5656
public :: syncwarp

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747

4848
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
4949
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
50-
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
50+
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
5151
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
5252
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
5353
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -79,17 +79,9 @@ end
7979
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
8080
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
8181

82-
! CHECK: func.func private @llvm.nvvm.barrier0()
83-
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
84-
! CHECK: func.func private @llvm.nvvm.membar.gl()
85-
! CHECK: func.func private @llvm.nvvm.membar.cta()
86-
! CHECK: func.func private @llvm.nvvm.membar.sys()
87-
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
88-
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
89-
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
90-
9182
subroutine host1()
9283
integer, device :: a(32)
84+
integer, device :: ret
9385
integer :: i, j
9486

9587
block; use cudadevice
@@ -98,6 +90,28 @@ block; use cudadevice
9890
a(i) = a(i) * 2.0
9991
call syncthreads()
10092
a(i) = a(i) + a(j) - 34.0
93+
94+
call syncwarp(1)
95+
ret = syncthreads_and(1)
96+
ret = syncthreads_count(1)
97+
ret = syncthreads_or(1)
10198
end do
10299
end block
103100
end
101+
102+
! CHECK-LABEL: func.func @_QPhost1()
103+
! CHECK: cuf.kernel
104+
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
105+
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
106+
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
107+
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
108+
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
109+
110+
! CHECK: func.func private @llvm.nvvm.barrier0()
111+
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
112+
! CHECK: func.func private @llvm.nvvm.membar.gl()
113+
! CHECK: func.func private @llvm.nvvm.membar.cta()
114+
! CHECK: func.func private @llvm.nvvm.membar.sys()
115+
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
116+
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
117+
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32

0 commit comments

Comments
 (0)