Commit 5f2a88a

fix fp16 scalar handling
1 parent 89909f3 commit 5f2a88a

3 files changed: +123 additions, −5 deletions

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 32 additions & 3 deletions

@@ -22,6 +22,33 @@ ctypedef cpp_complex.complex[float] cpp_single_complex
 ctypedef cpp_complex.complex[double] cpp_double_complex


+# We need an identifier for fp16 for copying scalars on the host. This is a minimal
+# implementation borrowed from cuda_fp16.h.
+cdef extern from *:
+    """
+    #if __cplusplus >= 201103L
+    #define __CUDA_ALIGN__(n) alignas(n)    /* C++11 kindly gives us a keyword for this */
+    #else
+    #if defined(__GNUC__)
+    #define __CUDA_ALIGN__(n) __attribute__ ((aligned(n)))
+    #elif defined(_MSC_VER)
+    #define __CUDA_ALIGN__(n) __declspec(align(n))
+    #else
+    #define __CUDA_ALIGN__(n)
+    #endif /* defined(__GNUC__) */
+    #endif /* __cplusplus >= 201103L */
+
+    typedef struct __CUDA_ALIGN__(2) {
+        /**
+         * Storage field contains bits representation of the \p half floating-point number.
+         */
+        unsigned short x;
+    } __half_raw;
+    """
+    ctypedef struct __half_raw:
+        unsigned short x
+
+
 ctypedef fused supported_type:
     cpp_bool
     int8_t
@@ -32,6 +59,7 @@ ctypedef fused supported_type:
     uint16_t
     uint32_t
     uint64_t
+    __half_raw
     float
     double
     intptr_t
@@ -85,6 +113,8 @@ cdef inline int prepare_arg(
         (<supported_type*>ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag)
     elif supported_type is cpp_double_complex:
         (<supported_type*>ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag)
+    elif supported_type is __half_raw:
+        (<supported_type*>ptr).x = <int16_t>(arg.view(numpy_int16))
     else:
         (<supported_type*>ptr)[0] = <supported_type>(arg)
     data_addresses[idx] = ptr  # take the address to the scalar
@@ -147,8 +177,7 @@ cdef inline int prepare_numpy_arg(
     elif isinstance(arg, numpy_uint64):
         return prepare_arg[uint64_t](data, data_addresses, arg, idx)
     elif isinstance(arg, numpy_float16):
-        # use int16 as a proxy
-        return prepare_arg[int16_t](data, data_addresses, arg, idx)
+        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
     elif isinstance(arg, numpy_float32):
         return prepare_arg[float](data, data_addresses, arg, idx)
     elif isinstance(arg, numpy_float64):
@@ -207,7 +236,7 @@ cdef class ParamHolder:
                 not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
                 if not_prepared:
                     # TODO: support ctypes/numpy struct
-                    raise TypeError
+                    raise TypeError("the argument is of unsupported type: " + str(type(arg)))

         self.kernel_args = kernel_args
         self.ptr = <intptr_t>self.data_addresses.data()
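
Note (not part of the commit): the new `__half_raw` branch copies the raw 16-bit pattern of the NumPy half into a 2-byte-aligned storage slot, which is exactly what `arg.view(numpy_int16)` yields on the Cython side, rather than numerically casting the value to int16 as the old proxy path did. A minimal NumPy sketch of that bit-level round trip (variable names are illustrative):

import numpy as np

# An fp16 scalar as a user would pass it to launch().
scalar = np.float16(0.78)

# What the handler now stores: the raw 16-bit pattern, not a lossy numeric cast.
bits = scalar.view(np.int16)

# Reinterpreting the bits recovers the exact half-precision value, so only the
# host-side storage type changes; the payload handed to the kernel is intact.
assert bits.view(np.float16) == scalar
print(hex(int(scalar.view(np.uint16))))  # raw bit pattern of the half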

cuda_core/docs/source/release/0.3.0-notes.rst

Lines changed: 1 addition & 0 deletions

@@ -30,3 +30,4 @@ Fixes and enhancements
 ----------------------

 - An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively.
+- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed

cuda_core/tests/test_launcher.py

Lines changed: 90 additions & 2 deletions

@@ -1,9 +1,15 @@
-# Copyright 2024 NVIDIA Corporation. All rights reserved.
+# Copyright 2024-2025 NVIDIA Corporation. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

+import ctypes
+import os
+import pathlib
+
+import numpy as np
 import pytest

-from cuda.core.experimental import Device, LaunchConfig, Program, launch
+from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
+from cuda.core.experimental._memory import _DefaultPinnedMemorySource


 def test_launch_config_init(init_cuda):
@@ -59,3 +65,85 @@ def test_launch_invalid_values(init_cuda):
         launch(stream, ker, None)

     launch(stream, config, ker)
+
+
+# Parametrize: (python_type, cpp_type, init_value)
+PARAMS = (
+    (bool, 'bool', True),
+    (float, 'double', 2.718),
+    (np.bool, 'bool', True),
+    (np.int8, 'signed char', -42),
+    (np.int16, 'signed short', -1234),
+    (np.int32, 'signed int', -123456),
+    (np.int64, 'signed long long', -123456789),
+    (np.uint8, 'unsigned char', 42),
+    (np.uint16, 'unsigned short', 1234),
+    (np.uint32, 'unsigned int', 123456),
+    (np.uint64, 'unsigned long long', 123456789),
+    (np.float32, 'float', 3.14),
+    (np.float64, 'double', 2.718),
+    (ctypes.c_bool, 'bool', True),
+    (ctypes.c_int8, 'signed char', -42),
+    (ctypes.c_int16, 'signed short', -1234),
+    (ctypes.c_int32, 'signed int', -123456),
+    (ctypes.c_int64, 'signed long long', -123456789),
+    (ctypes.c_uint8, 'unsigned char', 42),
+    (ctypes.c_uint16, 'unsigned short', 1234),
+    (ctypes.c_uint32, 'unsigned int', 123456),
+    (ctypes.c_uint64, 'unsigned long long', 123456789),
+    (ctypes.c_float, 'float', 3.14),
+    (ctypes.c_double, 'double', 2.718),
+)
+if os.environ.get("CUDA_PATH"):
+    PARAMS += (
+        (np.float16, 'half', 0.78),
+        (np.complex64, 'cuda::std::complex<float>', 1+2j),
+        (np.complex128, 'cuda::std::complex<double>', -3-4j),
+        (complex, 'cuda::std::complex<double>', 5-7j),
+    )
+
+@pytest.mark.parametrize("python_type, cpp_type, init_value", PARAMS)
+def test_launch_scalar_argument(python_type, cpp_type, init_value):
+    dev = Device()
+    dev.set_current()
+
+    # Prepare pinned host array
+    mr = _DefaultPinnedMemorySource()
+    b = mr.allocate(np.dtype(python_type).itemsize)
+    arr = np.from_dlpack(b).view(python_type)
+    arr[:] = 0
+
+    # Prepare scalar argument in Python
+    scalar = python_type(init_value)
+
+    # CUDA kernel templated on type T
+    code = r"""
+    template <typename T>
+    __global__ void write_scalar(T* arr, T val) {
+        arr[0] = val;
+    }
+    """
+
+    # Compile and force instantiation for this type
+    arch = "".join(f"{i}" for i in dev.compute_capability)
+    if os.environ.get("CUDA_PATH"):
+        include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / pathlib.Path("include"))
+        code = r"""
+        #include <cuda_fp16.h>
+        #include <cuda/std/complex>
+        """ + code
+    else:
+        include_path = None
+    pro_opts = ProgramOptions(std="c++11", arch=f"sm_{arch}", include_path=include_path)
+    prog = Program(code, code_type="c++", options=pro_opts)
+    ker_name = f"write_scalar<{cpp_type}>"
+    mod = prog.compile("cubin", name_expressions=(ker_name,))
+    ker = mod.get_kernel(ker_name)
+
+    # Launch with 1 thread
+    config = LaunchConfig(grid=1, block=1)
+    launch(dev.default_stream, config, ker, arr.ctypes.data, scalar)
+    dev.default_stream.sync()
+
+    # Check result
+    assert arr[0] == init_value, f"Expected {init_value}, got {arr[0]}"
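
Note (not part of the commit): condensed from the test above, this is a sketch of what the fix enables from user code, passing a `np.float16` scalar straight to :func:`launch`. It assumes `CUDA_PATH` points at a CUDA Toolkit so `cuda_fp16.h` resolves; the kernel and variable names are illustrative only.

import os
import pathlib

import numpy as np

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
from cuda.core.experimental._memory import _DefaultPinnedMemorySource

dev = Device()
dev.set_current()

# Pinned host buffer the kernel can write into, as in the test above.
mr = _DefaultPinnedMemorySource()
buf = mr.allocate(np.dtype(np.float16).itemsize)
arr = np.from_dlpack(buf).view(np.float16)
arr[:] = 0

# Templated kernel; write_scalar<half> is instantiated via a name expression.
code = r"""
#include <cuda_fp16.h>
template <typename T>
__global__ void write_scalar(T* arr, T val) { arr[0] = val; }
"""

arch = "".join(f"{i}" for i in dev.compute_capability)
include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / "include")
opts = ProgramOptions(std="c++11", arch=f"sm_{arch}", include_path=include_path)
mod = Program(code, code_type="c++", options=opts).compile(
    "cubin", name_expressions=("write_scalar<half>",))
ker = mod.get_kernel("write_scalar<half>")

# The fp16 scalar is passed as-is; the handler copies its bits into __half_raw.
launch(dev.default_stream, LaunchConfig(grid=1, block=1),
       ker, arr.ctypes.data, np.float16(0.78))
dev.default_stream.sync()
assert arr[0] == np.float16(0.78)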
