Skip to content

Commit 2452fdb

Browse files
committed
Save result of factoring out load_dl_common.py, load_dl_linux.py, load_dl_windows.py with the help of Cursor.
1 parent 2d65d44 commit 2452fdb

File tree

4 files changed

+223
-239
lines changed

4 files changed

+223
-239
lines changed
Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2-
#
32
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
43

54
from dataclasses import dataclass
6-
from typing import Optional
5+
from typing import Callable, Optional
6+
7+
from .supported_libs import DIRECT_DEPENDENCIES
78

89

910
@dataclass
1011
class LoadedDL:
12+
"""Represents a loaded dynamic library.
13+
14+
Attributes:
15+
handle: The library handle (can be converted to void* in Cython)
16+
abs_path: The absolute path to the library file
17+
was_already_loaded_from_elsewhere: Whether the library was already loaded
18+
"""
19+
1120
# ATTENTION: To convert `handle` back to `void*` in cython:
1221
# Linux: `cdef void* ptr = <void*><uintptr_t>`
1322
# Windows: `cdef void* ptr = <void*><intptr_t>`
@@ -17,13 +26,35 @@ class LoadedDL:
1726

1827

1928
def add_dll_directory(dll_abs_path: str) -> None:
20-
"""Add a DLL directory to the search path and update PATH environment variable."""
29+
"""Add a DLL directory to the search path and update PATH environment variable.
30+
31+
Args:
32+
dll_abs_path: Absolute path to the DLL file
33+
34+
Raises:
35+
AssertionError: If the directory containing the DLL does not exist
36+
"""
2137
import os
22-
38+
2339
dirpath = os.path.dirname(dll_abs_path)
2440
assert os.path.isdir(dirpath), dll_abs_path
2541
# Add the DLL directory to the search path
2642
os.add_dll_directory(dirpath)
2743
# Update PATH as a fallback for dependent DLL resolution
2844
curr_path = os.environ.get("PATH")
29-
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
45+
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
46+
47+
48+
def load_dependencies(libname: str, load_func: Callable[[str], LoadedDL]) -> None:
49+
"""Load all dependencies for a given library.
50+
51+
Args:
52+
libname: The name of the library whose dependencies should be loaded
53+
load_func: The function to use for loading libraries (e.g. load_nvidia_dynamic_library)
54+
55+
Example:
56+
>>> load_dependencies("cudart", load_nvidia_dynamic_library)
57+
# This will load all dependencies of cudart using the provided loading function
58+
"""
59+
for dep in DIRECT_DEPENDENCIES.get(libname, ()):
60+
load_func(dep)
Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2-
#
32
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
43

54
import ctypes
65
import ctypes.util
76
import os
8-
from typing import Optional, Tuple
7+
from typing import Optional
98

109
from .load_dl_common import LoadedDL
1110

12-
_LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
11+
CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
1312

14-
_LIBDL_PATH = ctypes.util.find_library("dl") or "libdl.so.2"
15-
_LIBDL = ctypes.CDLL(_LIBDL_PATH)
16-
_LIBDL.dladdr.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
17-
_LIBDL.dladdr.restype = ctypes.c_int
13+
LIBDL_PATH = ctypes.util.find_library("dl") or "libdl.so.2"
14+
LIBDL = ctypes.CDLL(LIBDL_PATH)
15+
LIBDL.dladdr.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
16+
LIBDL.dladdr.restype = ctypes.c_int
1817

1918

2019
class Dl_info(ctypes.Structure):
20+
"""Structure used by dladdr to return information about a loaded symbol."""
21+
2122
_fields_ = [
2223
("dli_fname", ctypes.c_char_p), # path to .so
2324
("dli_fbase", ctypes.c_void_p),
@@ -27,50 +28,98 @@ class Dl_info(ctypes.Structure):
2728

2829

2930
def abs_path_for_dynamic_library(libname: str, handle: int) -> Optional[str]:
30-
"""Get the absolute path of a loaded dynamic library on Linux."""
31+
"""Get the absolute path of a loaded dynamic library on Linux.
32+
33+
Args:
34+
libname: The name of the library
35+
handle: The library handle
36+
37+
Returns:
38+
The absolute path to the library file, or None if no expected symbol is found
39+
40+
Raises:
41+
OSError: If dladdr fails to get information about the symbol
42+
"""
3143
from .supported_libs import EXPECTED_LIB_SYMBOLS
32-
44+
3345
for symbol_name in EXPECTED_LIB_SYMBOLS[libname]:
3446
symbol = getattr(handle, symbol_name, None)
3547
if symbol is not None:
3648
break
3749
else:
3850
return None
39-
51+
4052
addr = ctypes.cast(symbol, ctypes.c_void_p)
4153
info = Dl_info()
42-
if _LIBDL.dladdr(addr, ctypes.byref(info)) == 0:
54+
if LIBDL.dladdr(addr, ctypes.byref(info)) == 0:
4355
raise OSError(f"dladdr failed for {libname=!r}")
4456
return info.dli_fname.decode()
4557

4658

47-
def load_and_report_path(libname: str, soname: str) -> Tuple[int, str]:
48-
"""Load a dynamic library and return its handle and absolute path."""
49-
handle = ctypes.CDLL(soname, _LINUX_CDLL_MODE)
50-
abs_path = abs_path_for_dynamic_library(libname, handle)
51-
if abs_path is None:
52-
raise RuntimeError(f"No expected symbol for {libname=!r}")
53-
return handle._handle, abs_path
59+
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
60+
"""Check if the library is already loaded in the process.
5461
62+
Args:
63+
libname: The name of the library to check
5564
56-
def load_dynamic_library(libname: str, found_path: str) -> LoadedDL:
57-
"""Load a dynamic library from the given path."""
58-
try:
59-
handle = ctypes.CDLL(found_path, _LINUX_CDLL_MODE)
60-
except OSError as e:
61-
raise RuntimeError(f"Failed to dlopen {found_path}: {e}") from e
62-
return LoadedDL(handle._handle, found_path, False)
63-
65+
Returns:
66+
A LoadedDL object if the library is already loaded, None otherwise
6467
65-
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
66-
"""Check if the library is already loaded in the process."""
68+
Example:
69+
>>> loaded = check_if_already_loaded("cudart")
70+
>>> if loaded is not None:
71+
... print(f"Library already loaded from {loaded.abs_path}")
72+
"""
6773
from .supported_libs import SUPPORTED_LINUX_SONAMES
68-
74+
6975
for soname in SUPPORTED_LINUX_SONAMES.get(libname, ()):
7076
try:
7177
handle = ctypes.CDLL(soname, mode=os.RTLD_NOLOAD)
7278
except OSError:
7379
continue
7480
else:
7581
return LoadedDL(handle._handle, abs_path_for_dynamic_library(libname, handle), True)
76-
return None
82+
return None
83+
84+
85+
def load_with_system_search(libname: str, soname: str) -> Optional[LoadedDL]:
86+
"""Try to load a library using system search paths.
87+
88+
Args:
89+
libname: The name of the library to load
90+
soname: The soname to search for
91+
92+
Returns:
93+
A LoadedDL object if successful, None if the library cannot be loaded
94+
95+
Raises:
96+
RuntimeError: If the library is loaded but no expected symbol is found
97+
"""
98+
try:
99+
handle = ctypes.CDLL(soname, CDLL_MODE)
100+
abs_path = abs_path_for_dynamic_library(libname, handle)
101+
if abs_path is None:
102+
raise RuntimeError(f"No expected symbol for {libname=!r}")
103+
return LoadedDL(handle._handle, abs_path, False)
104+
except OSError:
105+
return None
106+
107+
108+
def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
109+
"""Load a dynamic library from the given path.
110+
111+
Args:
112+
libname: The name of the library to load
113+
found_path: The absolute path to the library file
114+
115+
Returns:
116+
A LoadedDL object representing the loaded library
117+
118+
Raises:
119+
RuntimeError: If the library cannot be loaded
120+
"""
121+
try:
122+
handle = ctypes.CDLL(found_path, CDLL_MODE)
123+
except OSError as e:
124+
raise RuntimeError(f"Failed to dlopen {found_path}: {e}") from e
125+
return LoadedDL(handle._handle, found_path, False)

cuda_bindings/cuda/bindings/_path_finder/load_dl_windows.py

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2-
#
32
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
43

54
import ctypes
65
import ctypes.wintypes
7-
from typing import Optional, Tuple
6+
import functools
7+
from typing import Optional
88

99
import pywintypes
1010
import win32api
1111

12-
from .load_dl_common import LoadedDL
12+
from .load_dl_common import LoadedDL, add_dll_directory
1313

1414
# Mirrors WinBase.h (unfortunately not defined already elsewhere)
15-
_WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
16-
_WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
15+
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
16+
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
1717

1818

1919
def abs_path_for_dynamic_library(handle: int) -> str:
20-
"""Get the absolute path of a loaded dynamic library on Windows."""
20+
"""Get the absolute path of a loaded dynamic library on Windows.
21+
22+
Args:
23+
handle: The library handle
24+
25+
Returns:
26+
The absolute path to the DLL file
27+
28+
Raises:
29+
OSError: If GetModuleFileNameW fails
30+
"""
2131
buf = ctypes.create_unicode_buffer(260)
2232
n_chars = ctypes.windll.kernel32.GetModuleFileNameW(ctypes.wintypes.HMODULE(handle), buf, len(buf))
2333
if n_chars == 0:
@@ -27,7 +37,14 @@ def abs_path_for_dynamic_library(handle: int) -> str:
2737

2838
@functools.cache
2939
def cuDriverGetVersion() -> int:
30-
"""Get the CUDA driver version."""
40+
"""Get the CUDA driver version.
41+
42+
Returns:
43+
The CUDA driver version number
44+
45+
Raises:
46+
AssertionError: If the driver version cannot be obtained
47+
"""
3148
handle = win32api.LoadLibrary("nvcuda.dll")
3249

3350
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
@@ -45,50 +62,80 @@ def cuDriverGetVersion() -> int:
4562
return driver_ver.value
4663

4764

48-
@functools.cache
49-
def load_with_dll_basename(name: str) -> Tuple[Optional[int], Optional[str]]:
50-
"""Try to load a DLL by its basename."""
65+
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
66+
"""Check if the library is already loaded in the process.
67+
68+
Args:
69+
libname: The name of the library to check
70+
71+
Returns:
72+
A LoadedDL object if the library is already loaded, None otherwise
73+
74+
Example:
75+
>>> loaded = check_if_already_loaded("cudart")
76+
>>> if loaded is not None:
77+
... print(f"Library already loaded from {loaded.abs_path}")
78+
"""
79+
from .supported_libs import SUPPORTED_WINDOWS_DLLS
80+
81+
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
82+
try:
83+
handle = win32api.GetModuleHandle(dll_name)
84+
except pywintypes.error:
85+
continue
86+
else:
87+
return LoadedDL(handle, abs_path_for_dynamic_library(handle), True)
88+
return None
89+
90+
91+
def load_with_system_search(name: str, _unused: str) -> Optional[LoadedDL]:
92+
"""Try to load a DLL using system search paths.
93+
94+
Args:
95+
name: The name of the library to load
96+
_unused: Unused parameter (kept for interface consistency)
97+
98+
Returns:
99+
A LoadedDL object if successful, None if the library cannot be loaded
100+
"""
51101
from .supported_libs import SUPPORTED_WINDOWS_DLLS
52-
102+
53103
driver_ver = cuDriverGetVersion()
54104
del driver_ver # Keeping this here because it will probably be needed in the future.
55105

56106
dll_names = SUPPORTED_WINDOWS_DLLS.get(name)
57107
if dll_names is None:
58-
return None, None
108+
return None
59109

60110
for dll_name in dll_names:
61111
handle = ctypes.windll.kernel32.LoadLibraryW(ctypes.c_wchar_p(dll_name))
62112
if handle:
63-
return handle, abs_path_for_dynamic_library(handle)
113+
return LoadedDL(handle, abs_path_for_dynamic_library(handle), False)
114+
115+
return None
64116

65-
return None, None
66117

118+
def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
119+
"""Load a dynamic library from the given path.
67120
68-
def load_dynamic_library(libname: str, found_path: str) -> LoadedDL:
69-
"""Load a dynamic library from the given path."""
121+
Args:
122+
libname: The name of the library to load
123+
found_path: The absolute path to the DLL file
124+
125+
Returns:
126+
A LoadedDL object representing the loaded library
127+
128+
Raises:
129+
RuntimeError: If the DLL cannot be loaded
130+
"""
70131
from .supported_libs import LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY
71-
132+
72133
if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
73134
add_dll_directory(found_path)
74-
75-
flags = _WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | _WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
135+
136+
flags = WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
76137
try:
77138
handle = win32api.LoadLibraryEx(found_path, 0, flags)
78139
except pywintypes.error as e:
79140
raise RuntimeError(f"Failed to load DLL at {found_path}: {e}") from e
80141
return LoadedDL(handle, found_path, False)
81-
82-
83-
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
84-
"""Check if the library is already loaded in the process."""
85-
from .supported_libs import SUPPORTED_WINDOWS_DLLS
86-
87-
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
88-
try:
89-
handle = win32api.GetModuleHandle(dll_name)
90-
except pywintypes.error:
91-
continue
92-
else:
93-
return LoadedDL(handle, abs_path_for_dynamic_library(handle), True)
94-
return None

0 commit comments

Comments
 (0)