Skip to content

PyTorch example #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions cuda_core/examples/pytorch_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

## Usage: pip install "cuda-core[cu12]"
## python python_example.py
import sys

import torch

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

# SAXPY kernel - passing a as a pointer to avoid any type issues
code = """
template<typename T>
__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) {
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < N) {
// Dereference a to get the scalar value
out[tid] = (*a) * x[tid] + y[tid];
}
}
"""

dev = Device()
dev.set_current()

# Get PyTorch's current stream
pt_stream = torch.cuda.current_stream()
print(f"PyTorch stream: {pt_stream}")


# Create a wrapper class that implements __cuda_stream__
class PyTorchStreamWrapper:
def __init__(self, pt_stream):
self.pt_stream = pt_stream

def __cuda_stream__(self):
stream_id = self.pt_stream.cuda_stream
return (0, stream_id) # Return format required by CUDA Python


s = PyTorchStreamWrapper(pt_stream)

# prepare program
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile(
"cubin",
logs=sys.stdout,
name_expressions=("saxpy_kernel<float>", "saxpy_kernel<double>"),
)

# Run in single precision
ker = mod.get_kernel("saxpy_kernel<float>")
dtype = torch.float32

# prepare input/output
size = 64
# Use a single element tensor for 'a'
a = torch.tensor([10.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")
out = torch.empty_like(x)

# prepare launch
block = 32
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)

# launch kernel on our stream
launch(s, config, ker, *ker_args)

# check result
assert torch.allclose(out, a.item() * x + y)
print("Single precision test passed!")

# let's repeat again with double precision
ker = mod.get_kernel("saxpy_kernel<double>")
dtype = torch.float64

# prepare input
size = 128
# Use a single element tensor for 'a'
a = torch.tensor([42.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")

# prepare output
out = torch.empty_like(x)

# prepare launch
block = 64
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)

# launch kernel on PyTorch's stream
launch(s, config, ker, *ker_args)

# check result
assert torch.allclose(out, a * x + y)
print("Double precision test passed!")
print("All tests passed successfully!")
2 changes: 1 addition & 1 deletion cuda_core/tests/example_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run_example(samples_path, filename, env=None):
exec(script, env if env else {}) # nosec B102
except ImportError as e:
# for samples requiring any of optional dependencies
for m in ("cupy",):
for m in ("cupy", "torch"):
if f"No module named '{m}'" in str(e):
pytest.skip(f"{m} not installed, skipping related tests")
break
Expand Down
Loading