
Commit 9572f8a

add a code sample for strided memory view
1 parent b5cfdce commit 9572f8a

File tree: 2 files changed (+218 −36 lines)
Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# ################################################################################
#
# This demo aims to illustrate two takeaways:
#
#   1. The similarity between CPU and GPU JIT-compilation for C++ sources
#   2. How to use StridedMemoryView to interface with foreign CPU/GPU functions
#      at a low level
#
# To facilitate this demo, we use cffi (https://cffi.readthedocs.io/) for the CPU
# path, which can be easily installed from pip or conda following its instructions.
# We also use NumPy/CuPy as the CPU/GPU array container.
#
# ################################################################################

import string
import sys

try:
    from cffi import FFI
except ImportError:
    print("cffi is not installed, the CPU example will be skipped", file=sys.stderr)
    FFI = None  # the CPU path below checks `if FFI:`, so unset the class itself
try:
    import cupy as cp
except ImportError:
    print("cupy is not installed, the GPU example will be skipped", file=sys.stderr)
    cp = None
import numpy as np

from cuda.core.experimental import Device, LaunchConfig, Program, launch
from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory


# ################################################################################
#
# Usually this entire code block is in a separate file, built as a Python extension
# module that can be imported by users at run time. For illustrative purposes we
# use JIT compilation to make this demo self-contained.
#
# Here we assume that an in-place operation, equivalent to the following NumPy code:
#
#   >>> arr = ...
#   >>> assert arr.dtype == np.int32
#   >>> assert arr.ndim == 1
#   >>> arr += np.arange(arr.size, dtype=arr.dtype)
#
# is implemented for both CPU and GPU at a low level, with the following C function
# signature:
func_name = "inplace_plus_arange_N"
func_sig = f"void {func_name}(int* data, size_t N)"

# Here is a concrete (very naive!) implementation on CPU:
if FFI:
    cpu_code = string.Template(r"""
    extern "C" {
        $func_sig {
            for (size_t i = 0; i < N; i++) {
                data[i] += i;
            }
        }
    }
    """).substitute(func_sig=func_sig)
    cpu_prog = FFI()
    cpu_prog.set_source("_cpu_obj", cpu_code, source_extension=".cpp")
    cpu_prog.cdef(f"{func_sig};")
    cpu_prog.compile()
    # This is cffi's way of loading a CPU function. cffi builds an extension module
    # that has the Python binding to the underlying C function. (For more details,
    # please refer to cffi's documentation.)
    from _cpu_obj.lib import inplace_plus_arange_N as cpu_func
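
    # A minimal sanity-check sketch (assuming a 1D contiguous int32 array): the
    # compiled function can also be called directly with a raw pointer obtained
    # from NumPy, without going through StridedMemoryView.
    _check = np.zeros(8, dtype=np.int32)
    cpu_func(cpu_prog.cast("int*", _check.ctypes.data), _check.size)
    assert (_check == np.arange(8, dtype=np.int32)).all()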

# Here is a concrete (again, very naive!) implementation on GPU:
if cp:
    gpu_code = string.Template(r"""
    extern "C"
    __global__ $func_sig {
        const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
        const size_t stride_size = gridDim.x * blockDim.x;
        for (size_t i = tid; i < N; i += stride_size) {
            data[i] += i;
        }
    }
    """).substitute(func_sig=func_sig)
    gpu_prog = Program(gpu_code, code_type="c++")
    # To know the GPU's compute capability, we need to identify which GPU to use.
    dev = Device(0)
    # e.g. compute capability (8, 6) becomes the arch string "86"
    arch = "".join(f"{i}" for i in dev.compute_capability)
    mod = gpu_prog.compile(
        target_type="cubin",
        # TODO: update this after NVIDIA/cuda-python#237 is merged
        options=(f"-arch=sm_{arch}", "-std=c++11"))
    gpu_ker = mod.get_kernel(func_name)
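
# Note that the kernel body above uses a grid-stride loop, so its correctness does
# not depend on launching exactly ceil(N / block) blocks: any positive grid size
# covers all N elements. (An observation about this sample's kernel, not a
# cuda.core requirement.)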

# Now we are prepared to run the code from the user's perspective!
#
# ################################################################################


# Below, as a user, we want to perform the said in-place operation on either CPU
# or GPU, by calling the corresponding function implemented "elsewhere" (done above).

@args_viewable_as_strided_memory((0,))
def my_func(arr, work_stream):
    # Create a memory view over arr, assumed to be a 1D array of int32.
    view = arr.view(work_stream.handle if work_stream else -1)
    assert isinstance(view, StridedMemoryView)
    assert len(view.shape) == 1
    assert view.dtype == np.int32

    size = view.shape[0]
    if view.is_device_accessible:
        block = 256
        # Ceiling division: this covers sizes not divisible by block, and unlike
        # size // block it never yields an (invalid) grid of 0 for small sizes.
        grid = (size + block - 1) // block
        config = LaunchConfig(grid=grid, block=block, stream=work_stream)
        launch(gpu_ker, config, view.ptr, np.uint64(size))
        # Here we're being conservative and synchronize over our work stream,
        # assuming we do not know the (producer/source) stream; if we knew it,
        # we could simply order the producer/consumer streams here, e.g.
        #
        #   producer_stream.wait(work_stream)
        #
        # without an expensive synchronization.
        work_stream.sync()
    else:
        cpu_func(cpu_prog.cast("int*", view.ptr), size)
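
# The decorator takes a tuple of positional-argument indices to make viewable; a
# hypothetical two-array variant (a sketch only, not used below) could look like:
#
#   @args_viewable_as_strided_memory((0, 1))
#   def my_binary_func(arr_x, arr_y, work_stream):
#       view_x = arr_x.view(work_stream.handle if work_stream else -1)
#       view_y = arr_y.view(work_stream.handle if work_stream else -1)
#       ...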

# This takes the CPU path
if FFI:
    # Create input array on CPU
    arr_cpu = np.zeros(1024, dtype=np.int32)
    print(f"before: {arr_cpu[:10]=}")

    # Run the workload
    my_func(arr_cpu, None)

    # Check the result
    print(f"after: {arr_cpu[:10]=}")
    assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))


# This takes the GPU path
if cp:
    dev.set_current()
    s = dev.create_stream()
    # Create input array on GPU
    arr_gpu = cp.ones(1024, dtype=cp.int32)
    print(f"before: {arr_gpu[:10]=}")

    # Run the workload
    my_func(arr_gpu, s)

    # Check the result
    print(f"after: {arr_gpu[:10]=}")
    assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
    s.close()

cuda_core/tests/conftest.py

Lines changed: 56 additions & 36 deletions
@@ -1,36 +1,56 @@
-# Copyright 2024 NVIDIA Corporation. All rights reserved.
-#
-# Please refer to the NVIDIA end user license agreement (EULA) associated
-# with this source code for terms and conditions that govern your use of
-# this software. Any use, reproduction, disclosure, or distribution of
-# this software and related documentation outside the terms of the EULA
-# is strictly prohibited.
-try:
-    from cuda.bindings import driver
-except ImportError:
-    from cuda import cuda as driver
-
-import pytest
-
-from cuda.core.experimental import Device, _device
-from cuda.core.experimental._utils import handle_return
-
-
-@pytest.fixture(scope="function")
-def init_cuda():
-    device = Device()
-    device.set_current()
-    yield
-    _device_unset_current()
-
-
-def _device_unset_current():
-    handle_return(driver.cuCtxPopCurrent())
-    with _device._tls_lock:
-        del _device._tls.devices
-
-
-@pytest.fixture(scope="function")
-def deinit_cuda():
-    yield
-    _device_unset_current()
+# Copyright 2024 NVIDIA Corporation. All rights reserved.
+#
+# Please refer to the NVIDIA end user license agreement (EULA) associated
+# with this source code for terms and conditions that govern your use of
+# this software. Any use, reproduction, disclosure, or distribution of
+# this software and related documentation outside the terms of the EULA
+# is strictly prohibited.
+
+import glob
+import os
+import sys
+
+try:
+    from cuda.bindings import driver
+except ImportError:
+    from cuda import cuda as driver
+
+import pytest
+
+from cuda.core.experimental import Device, _device
+from cuda.core.experimental._utils import handle_return
+
+
+@pytest.fixture(scope="function")
+def init_cuda():
+    device = Device()
+    device.set_current()
+    yield
+    _device_unset_current()
+
+
+def _device_unset_current():
+    handle_return(driver.cuCtxPopCurrent())
+    with _device._tls_lock:
+        del _device._tls.devices
+
+
+@pytest.fixture(scope="function")
+def deinit_cuda():
+    yield
+    _device_unset_current()
+
+
+# Samples relying on cffi could fail as the built modules cannot otherwise be imported
+sys.path.append(os.getcwd())
+
+
+@pytest.fixture(scope="session", autouse=True)
+def clean_up_cffi_files():
+    yield
+    files = glob.glob(os.path.join(os.getcwd(), "_cpu_obj*"))
+    for f in files:
+        try:
+            os.remove(f)
+        except FileNotFoundError:
+            pass
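
For context, a minimal sketch of what the sample's cffi build leaves behind in the
working directory (the exact file set varies by platform and Python version; these
names are illustrative):

    _cpu_obj.cpp                                # generated source (source_extension=".cpp")
    _cpu_obj.o                                  # intermediate object file
    _cpu_obj.cpython-312-x86_64-linux-gnu.so    # the importable extension module

This is why the conftest appends os.getcwd() to sys.path and why the cleanup
fixture globs for "_cpu_obj*".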
