Skip to content

Make test_nvjitlink.py less noisy. #666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 37 additions & 39 deletions cuda_bindings/tests/test_nvjitlink.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import pytest
Expand All @@ -11,15 +10,13 @@
# Values substituted into the ``.version`` directive of the PTX header; zipped
# positionally with ARCHITECTURES when the parametrized PTX cases are built.
PTX_VERSIONS = ["5.0", "6.4", "7.0", "8.5"]


def ptx_header(version, arch):
return f"""
.version {version}
.target {arch}
PTX_HEADER = """\
.version {VERSION}
.target {ARCH}
.address_size 64
"""


ptx_kernel = """
PTX_KERNEL = """
.visible .entry _Z6kernelPi(
.param .u64 _Z6kernelPi_param_0
)
Expand All @@ -36,20 +33,21 @@ def ptx_header(version, arch):
}
"""

minimal_ptx_kernel = """
.func _MinimalKernel()
{
ret;
}
"""

ptx_kernel_bytes = [
(ptx_header(version, arch) + ptx_kernel).encode("utf-8") for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
]
minimal_ptx_kernel_bytes = [
(ptx_header(version, arch) + minimal_ptx_kernel).encode("utf-8")
for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
]
def _build_arch_ptx_parametrized_callable():
    """Build the shared ``pytest.mark.parametrize`` decorator for PTX tests.

    Each entry of ``ARCHITECTURES`` is paired with the entry of
    ``PTX_VERSIONS`` at the same position. Every pair yields one
    ``(arch, ptx_bytes)`` case, where ``ptx_bytes`` is ``PTX_HEADER`` filled
    in for that pair, followed by ``PTX_KERNEL``, encoded as UTF-8. Explicit
    case ids of the form ``"<arch>_<version>"`` are supplied instead of
    relying on pytest's auto-generated ids.
    """
    cases = []
    case_ids = []
    for arch, version in zip(ARCHITECTURES, PTX_VERSIONS):
        ptx_source = PTX_HEADER.format(VERSION=version, ARCH=arch) + PTX_KERNEL
        cases.append((arch, ptx_source.encode("utf-8")))
        case_ids.append(f"{arch}_{version}")
    return pytest.mark.parametrize(("arch", "ptx_bytes"), cases, ids=case_ids)


# Built once at import time; arch_ptx_parametrized() applies this same
# decorator object to every test it wraps.
ARCH_PTX_PARAMETRIZED_CALLABLE = _build_arch_ptx_parametrized_callable()


def arch_ptx_parametrized(func):
    """Decorate ``func`` with the shared ``(arch, ptx_bytes)`` parametrization.

    Applies ``ARCH_PTX_PARAMETRIZED_CALLABLE`` to ``func`` and returns the
    result, so tests can simply write ``@arch_ptx_parametrized``.
    """
    parametrized_func = ARCH_PTX_PARAMETRIZED_CALLABLE(func)
    return parametrized_func


def check_nvjitlink_usable():
Expand Down Expand Up @@ -108,27 +106,27 @@ def test_complete_empty(option):
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_data(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
@arch_ptx_parametrized
def test_add_data(arch, ptx_bytes):
    """Add a PTX kernel from an in-memory buffer and complete the link.

    Args:
        arch: Architecture string passed to the linker as ``-arch=<arch>``.
        ptx_bytes: Encoded PTX source handed to ``nvjitlink.add_data``.
    """
    handle = nvjitlink.create(1, [f"-arch={arch}"])
    try:
        nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
        nvjitlink.complete(handle)
    finally:
        # Release the linker handle even when add_data/complete raises, so a
        # failing case does not leak it into the rest of the test session.
        nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_file(option, ptx_bytes, tmp_path):
handle = nvjitlink.create(1, [f"-arch={option}"])
@arch_ptx_parametrized
def test_add_file(arch, ptx_bytes, tmp_path):
    """Add a PTX kernel from a file on disk and complete the link.

    Args:
        arch: Architecture string passed to the linker as ``-arch=<arch>``.
        ptx_bytes: Encoded PTX source written to the temporary input file.
        tmp_path: Directory in which the input file is created
            (pytest's ``tmp_path`` fixture).
    """
    handle = nvjitlink.create(1, [f"-arch={arch}"])
    try:
        # NOTE(review): the file carries a ``.cubin`` name but holds PTX text;
        # presumably InputType.ANY lets the linker detect the format from the
        # contents -- confirm before renaming.
        file_path = tmp_path / "test_file.cubin"
        file_path.write_bytes(ptx_bytes)
        nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
        nvjitlink.complete(handle)
    finally:
        # Release the linker handle even when an earlier step raises, so a
        # failing case does not leak it into the rest of the test session.
        nvjitlink.destroy(handle)


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_get_error_log(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
@pytest.mark.parametrize("arch", ARCHITECTURES)
def test_get_error_log(arch):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.complete(handle)
log_size = nvjitlink.get_error_log_size(handle)
log = bytearray(log_size)
Expand All @@ -137,9 +135,9 @@ def test_get_error_log(option):
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_get_info_log(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
@arch_ptx_parametrized
def test_get_info_log(arch, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
log_size = nvjitlink.get_info_log_size(handle)
Expand All @@ -149,9 +147,9 @@ def test_get_info_log(option, ptx_bytes):
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_get_linked_cubin(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
@arch_ptx_parametrized
def test_get_linked_cubin(arch, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
cubin_size = nvjitlink.get_linked_cubin_size(handle)
Expand All @@ -161,9 +159,9 @@ def test_get_linked_cubin(option, ptx_bytes):
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_get_linked_ptx(option, get_dummy_ltoir):
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
@pytest.mark.parametrize("arch", ARCHITECTURES)
def test_get_linked_ptx(arch, get_dummy_ltoir):
handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"])
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
nvjitlink.complete(handle)
ptx_size = nvjitlink.get_linked_ptx_size(handle)
Expand Down
Loading