Skip to content

Commit 299ce8d

Browse files
committed
improve type hints to make Ruff happy
1 parent 8384937 commit 299ce8d

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

cuda_core/cuda/core/experimental/_event.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
from __future__ import annotations
6+
57
import weakref
68
from dataclasses import dataclass
7-
from typing import Optional
9+
from typing import TYPE_CHECKING, Optional
810

911
from cuda.core.experimental._utils import CUDAError, check_or_create_options, driver, handle_return
1012

13+
if TYPE_CHECKING:
14+
import cuda.bindings
15+
1116

1217
@dataclass
1318
class EventOptions:
@@ -130,6 +135,6 @@ def is_done(self) -> bool:
130135
raise CUDAError(f"unexpected error: {result}")
131136

132137
@property
133-
def handle(self) -> "CUevent":
138+
def handle(self) -> cuda.bindings.driver.CUevent:
134139
"""Return the underlying cudaEvent_t pointer address as Python int."""
135140
return self._mnff.handle

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
from __future__ import annotations
6+
57
import ctypes
68
import weakref
79
from contextlib import contextmanager
810
from dataclasses import dataclass
9-
from typing import List, Optional, Tuple, Union
11+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
1012
from warnings import warn
1113

14+
if TYPE_CHECKING:
15+
import cuda.bindings
16+
1217
from cuda.core.experimental._device import Device
1318
from cuda.core.experimental._module import ObjectCode
1419
from cuda.core.experimental._utils import check_or_create_options, driver, handle_return, is_sequence
@@ -324,7 +329,7 @@ def _exception_manager(self):
324329

325330

326331
nvJitLinkHandleT = int
327-
LinkerHandleT = Union[nvJitLinkHandleT, "CUlinkState"]
332+
LinkerHandleT = Union[nvJitLinkHandleT, "cuda.bindings.driver.CUlinkState"]
328333

329334

330335
class Linker:

cuda_core/cuda/core/experimental/_program.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
from __future__ import annotations
6+
57
import weakref
68
from dataclasses import dataclass
7-
from typing import List, Optional, Tuple, Union
9+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
810
from warnings import warn
911

12+
if TYPE_CHECKING:
13+
import cuda.bindings
14+
1015
from cuda.core.experimental._device import Device
1116
from cuda.core.experimental._linker import Linker, LinkerHandleT, LinkerOptions
1217
from cuda.core.experimental._module import ObjectCode
@@ -331,6 +336,9 @@ def __repr__(self):
331336
return self._formatted_options
332337

333338

339+
ProgramHandleT = Union["cuda.bindings.nvrtc.nvrtcProgram", LinkerHandleT]
340+
341+
334342
class Program:
335343
"""Represent a compilation machinery to process programs into
336344
:obj:`~_module.ObjectCode`.
@@ -498,6 +506,6 @@ def backend(self) -> str:
498506
return self._backend
499507

500508
@property
501-
def handle(self) -> Union["nvrtcProgram", LinkerHandleT]:
509+
def handle(self) -> ProgramHandleT:
502510
"""Return the underlying handle object."""
503511
return self._mnff.handle

cuda_core/cuda/core/experimental/_stream.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import TYPE_CHECKING, Optional, Tuple, Union
1212

1313
if TYPE_CHECKING:
14+
import cuda.bindings
1415
from cuda.core.experimental._device import Device
1516
from cuda.core.experimental._context import Context
1617
from cuda.core.experimental._event import Event, EventOptions
@@ -147,7 +148,7 @@ def __cuda_stream__(self) -> Tuple[int, int]:
147148
return (0, self.handle)
148149

149150
@property
150-
def handle(self) -> "CUstream":
151+
def handle(self) -> cuda.bindings.driver.CUstream:
151152
"""Return the underlying cudaStream_t pointer address as Python int."""
152153
return self._mnff.handle
153154

0 commit comments

Comments
 (0)