-
Notifications
You must be signed in to change notification settings - Fork 171
PyTorch example #579
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
PyTorch example #579
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
598c105
PyTorch example
msaroufim ff161bd
lint
msaroufim e33c2c2
simplify example
msaroufim aa58a6e
signoff
msaroufim a4b9e82
Merge branch 'main' into pytorch_example
msaroufim 1378e6a
test suite changes
msaroufim 4609a2f
Update cuda_core/examples/pytorch_example.py
msaroufim c265c2a
Update cuda_core/examples/pytorch_example.py
msaroufim 3917b40
Update requirements-cu12.txt
msaroufim 3bc93b0
Update requirements-cu11.txt
msaroufim 01ad808
Merge branch 'main' into pytorch_example
msaroufim cef69e5
Update requirements-cu12.txt
msaroufim f840546
Update requirements-cu11.txt
msaroufim 5f68dbe
Update pytorch_example.py
msaroufim 66964ee
Update pytorch_example.py
msaroufim a750cdc
Update cuda_core/examples/pytorch_example.py
msaroufim 7113750
remove .item() call
msaroufim c2b7702
Merge branch 'main' into pytorch_example
msaroufim 7d3582f
Defer CI setup to later
leofang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
## Usage: pip install "cuda-core[cu12]" | ||
## python python_example.py | ||
import sys | ||
|
||
import torch | ||
|
||
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch | ||
|
||
# SAXPY kernel - passing a as a pointer to avoid any type issues | ||
code = """ | ||
template<typename T> | ||
__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) { | ||
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
if (tid < N) { | ||
// Dereference a to get the scalar value | ||
out[tid] = (*a) * x[tid] + y[tid]; | ||
} | ||
} | ||
""" | ||
|
||
dev = Device() | ||
dev.set_current() | ||
|
||
# Get PyTorch's current stream | ||
pt_stream = torch.cuda.current_stream() | ||
print(f"PyTorch stream: {pt_stream}") | ||
|
||
|
||
# Create a wrapper class that implements __cuda_stream__ | ||
class PyTorchStreamWrapper: | ||
def __init__(self, pt_stream): | ||
self.pt_stream = pt_stream | ||
|
||
def __cuda_stream__(self): | ||
stream_id = self.pt_stream.cuda_stream | ||
return (0, stream_id) # Return format required by CUDA Python | ||
|
||
|
||
s = PyTorchStreamWrapper(pt_stream) | ||
|
||
# prepare program | ||
arch = "".join(f"{i}" for i in dev.compute_capability) | ||
program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}") | ||
prog = Program(code, code_type="c++", options=program_options) | ||
mod = prog.compile( | ||
"cubin", | ||
logs=sys.stdout, | ||
name_expressions=("saxpy_kernel<float>", "saxpy_kernel<double>"), | ||
) | ||
|
||
# Run in single precision | ||
ker = mod.get_kernel("saxpy_kernel<float>") | ||
dtype = torch.float32 | ||
|
||
# prepare input/output | ||
size = 64 | ||
# Use a single element tensor for 'a' | ||
a = torch.tensor([10.0], dtype=dtype, device="cuda") | ||
x = torch.rand(size, dtype=dtype, device="cuda") | ||
y = torch.rand(size, dtype=dtype, device="cuda") | ||
out = torch.empty_like(x) | ||
|
||
# prepare launch | ||
block = 32 | ||
grid = int((size + block - 1) // block) | ||
config = LaunchConfig(grid=grid, block=block) | ||
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) | ||
|
||
# launch kernel on our stream | ||
launch(s, config, ker, *ker_args) | ||
|
||
# check result | ||
assert torch.allclose(out, a.item() * x + y) | ||
print("Single precision test passed!") | ||
|
||
# let's repeat again with double precision | ||
ker = mod.get_kernel("saxpy_kernel<double>") | ||
dtype = torch.float64 | ||
|
||
# prepare input | ||
size = 128 | ||
# Use a single element tensor for 'a' | ||
a = torch.tensor([42.0], dtype=dtype, device="cuda") | ||
x = torch.rand(size, dtype=dtype, device="cuda") | ||
y = torch.rand(size, dtype=dtype, device="cuda") | ||
|
||
# prepare output | ||
out = torch.empty_like(x) | ||
|
||
# prepare launch | ||
block = 64 | ||
grid = int((size + block - 1) // block) | ||
config = LaunchConfig(grid=grid, block=block) | ||
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) | ||
|
||
# launch kernel on PyTorch's stream | ||
launch(s, config, ker, *ker_args) | ||
|
||
# check result | ||
assert torch.allclose(out, a * x + y) | ||
print("Double precision test passed!") | ||
print("All tests passed successfully!") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.