|
1 |
| -# Copyright 2024 NVIDIA Corporation. All rights reserved. |
| 1 | +# Copyright 2024-2025 NVIDIA Corporation. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 | 3 |
|
| 4 | +import ctypes |
| 5 | +import os |
| 6 | +import pathlib |
| 7 | + |
| 8 | +import numpy as np |
4 | 9 | import pytest
|
5 | 10 |
|
6 |
| -from cuda.core.experimental import Device, LaunchConfig, Program, launch |
| 11 | +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch |
| 12 | +from cuda.core.experimental._memory import _DefaultPinnedMemorySource |
7 | 13 |
|
8 | 14 |
|
9 | 15 | def test_launch_config_init(init_cuda):
|
@@ -59,3 +65,85 @@ def test_launch_invalid_values(init_cuda):
|
59 | 65 | launch(stream, ker, None)
|
60 | 66 |
|
61 | 67 | launch(stream, config, ker)
|
| 68 | + |
| 69 | + |
# Parametrize: (python_type, cpp_type, init_value)
# Each entry maps a host-side scalar type (builtin, NumPy, or ctypes) to the
# C++ type name used to instantiate the templated kernel, plus a test value.
PARAMS = (
    (bool, "bool", True),
    (float, "double", 2.718),
    # Use np.bool_ (not the bare np.bool alias): np.bool was removed in
    # NumPy 1.24-1.26 and only restored as an alias in NumPy 2.0, so np.bool_
    # is the spelling that works across both major versions.
    (np.bool_, "bool", True),
    (np.int8, "signed char", -42),
    (np.int16, "signed short", -1234),
    (np.int32, "signed int", -123456),
    (np.int64, "signed long long", -123456789),
    (np.uint8, "unsigned char", 42),
    (np.uint16, "unsigned short", 1234),
    (np.uint32, "unsigned int", 123456),
    (np.uint64, "unsigned long long", 123456789),
    (np.float32, "float", 3.14),
    (np.float64, "double", 2.718),
    (ctypes.c_bool, "bool", True),
    (ctypes.c_int8, "signed char", -42),
    (ctypes.c_int16, "signed short", -1234),
    (ctypes.c_int32, "signed int", -123456),
    (ctypes.c_int64, "signed long long", -123456789),
    (ctypes.c_uint8, "unsigned char", 42),
    (ctypes.c_uint16, "unsigned short", 1234),
    (ctypes.c_uint32, "unsigned int", 123456),
    (ctypes.c_uint64, "unsigned long long", 123456789),
    (ctypes.c_float, "float", 3.14),
    (ctypes.c_double, "double", 2.718),
)
# half/complex support needs CUDA toolkit headers (cuda_fp16.h,
# cuda/std/complex), which are only available when CUDA_PATH points at a
# full CUDA installation.
if os.environ.get("CUDA_PATH"):
    PARAMS += (
        (np.float16, "half", 0.78),
        (np.complex64, "cuda::std::complex<float>", 1 + 2j),
        (np.complex128, "cuda::std::complex<double>", -3 - 4j),
        (complex, "cuda::std::complex<double>", 5 - 7j),
    )
| 104 | + |
@pytest.mark.parametrize("python_type, cpp_type, init_value", PARAMS)
def test_launch_scalar_argument(python_type, cpp_type, init_value):
    """Pass a host scalar of each supported type to a templated kernel.

    Compiles ``write_scalar<T>`` for the C++ type paired with *python_type*,
    launches it with a scalar built from *init_value*, and verifies that the
    device wrote that value into a pinned host array.
    """
    dev = Device()
    dev.set_current()

    # Prepare a one-element pinned host array for the kernel to write into.
    mr = _DefaultPinnedMemorySource()
    b = mr.allocate(np.dtype(python_type).itemsize)
    arr = np.from_dlpack(b).view(python_type)
    arr[:] = 0

    # Host-side scalar argument to be marshalled through the launch machinery.
    scalar = python_type(init_value)

    # CUDA kernel templated on type T.
    code = r"""
    template <typename T>
    __global__ void write_scalar(T* arr, T val) {
        arr[0] = val;
    }
    """

    # Look up CUDA_PATH once: it gates both the extra headers (needed for the
    # half/complex entries added to PARAMS) and the include path passed to the
    # compiler.
    cuda_path = os.environ.get("CUDA_PATH")
    if cuda_path:
        include_path = str(pathlib.Path(cuda_path) / "include")
        code = r"""
    #include <cuda_fp16.h>
    #include <cuda/std/complex>
    """ + code
    else:
        include_path = None

    # Compile for the current device and force instantiation for this type.
    arch = "".join(str(i) for i in dev.compute_capability)
    pro_opts = ProgramOptions(std="c++11", arch=f"sm_{arch}", include_path=include_path)
    prog = Program(code, code_type="c++", options=pro_opts)
    ker_name = f"write_scalar<{cpp_type}>"
    mod = prog.compile("cubin", name_expressions=(ker_name,))
    ker = mod.get_kernel(ker_name)

    # Launch with a single thread; arr.ctypes.data is the raw pinned pointer.
    config = LaunchConfig(grid=1, block=1)
    launch(dev.default_stream, config, ker, arr.ctypes.data, scalar)
    dev.default_stream.sync()

    # Verify the device wrote the expected value back to the host array.
    assert arr[0] == init_value, f"Expected {init_value}, got {arr[0]}"
0 commit comments