Skip to content

Commit 18a9c85

Browse files
committed
also test nvrtc
1 parent 650e3ae commit 18a9c85

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

cuda_core/tests/cython/test_get_cuda_native_handle.pyx

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ from libc.stdint cimport intptr_t
99

1010
from cuda.bindings.driver cimport (CUstream as pyCUstream,
1111
CUevent as pyCUevent)
12+
from cuda.bindings.nvrtc cimport nvrtcProgram as pynvrtcProgram
1213
from cuda.bindings.cydriver cimport CUstream, CUevent
14+
from cuda.bindings.cynvrtc cimport nvrtcProgram
1315

14-
from cuda.core.experimental import Device
16+
from cuda.core.experimental import Device, Program
1517

1618

1719
cdef extern from "utility.hpp":
@@ -32,5 +34,12 @@ def test_get_cuda_native_handle():
3234
cdef CUevent e_c = <CUevent>get_cuda_native_handle(e_py)
3335
assert <intptr_t>(e_c) == <intptr_t>(int(e_py))
3436

37+
prog = Program("extern \"C\" __global__ void dummy() {}", "c++")
38+
assert prog.backend == "NVRTC"
39+
cdef pynvrtcProgram prog_py = prog.handle
40+
cdef nvrtcProgram prog_c = <nvrtcProgram>get_cuda_native_handle(prog_py)
41+
assert <intptr_t>(prog_c) == <intptr_t>(int(prog_py))
42+
43+
prog.close()
3544
e.close()
3645
s.close()

0 commit comments

Comments
 (0)