|
19 | 19 | class LaunchConfig:
|
20 | 20 | """
|
21 | 21 | """
|
| 22 | + # TODO: expand LaunchConfig to include other attributes |
22 | 23 | grid: Union[tuple, int] = None
|
23 | 24 | block: Union[tuple, int] = None
|
24 | 25 | stream: Stream = None
|
def launch(kernel, config, *kernel_args):
    """Launch *kernel* with the grid/block/stream described by *config*.

    Parameters appear to be: *kernel*, a ``Kernel`` wrapping a driver kernel
    handle; *config*, a ``LaunchConfig`` (or options convertible to one via
    ``check_or_create_options``); *kernel_args*, the positional arguments
    forwarded to the kernel.  Raises ``ValueError`` if *kernel* is not a
    ``Kernel`` and ``CUDAError`` if no stream was provided.
    """
    if not isinstance(kernel, Kernel):
        # NOTE(review): bare ValueError with no message — presumably a
        # placeholder; confirm whether an explanatory message is wanted.
        raise ValueError
    config = check_or_create_options(LaunchConfig, config, "launch config")
    # A stream is mandatory for both launch paths below (its ._handle is
    # passed to the driver), so reject None up front.
    if config.stream is None:
        raise CUDAError("stream cannot be None")

    # TODO: can we ensure kernel_args is valid/safe to use here?
    # TODO: merge with HelperKernelParams?
    # ParamHolder marshals the Python-level arguments into the pointer
    # array form the driver launch APIs expect; keep the holder alive
    # (bound to kernel_args) until after the launch call.
    kernel_args = ParamHolder(kernel_args)
    args_ptr = kernel_args.ptr

    # Driver version gates the launch API: cuLaunchKernelEx (with a
    # CUlaunchConfig struct) is a CUDA 12.0+ driver API.
    driver_ver = handle_return(cuda.cuDriverGetVersion())
    if driver_ver >= 12000:
        drv_cfg = cuda.CUlaunchConfig()
        # config.grid / config.block are unpacked as 3-tuples here —
        # presumably LaunchConfig normalizes int/partial tuples to
        # (x, y, z); confirm in LaunchConfig's post-init.
        drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
        drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
        drv_cfg.hStream = config.stream._handle
        drv_cfg.sharedMemBytes = config.shmem_size
        drv_cfg.numAttrs = 0  # TODO
        handle_return(cuda.cuLaunchKernelEx(
            drv_cfg, int(kernel._handle), args_ptr, 0))
    else:
        # Pre-12.0 fallback: legacy cuLaunchKernel takes the grid/block
        # dimensions and shared-memory size as flat arguments.
        # TODO: check if config has any unsupported attrs
        handle_return(cuda.cuLaunchKernel(
            int(kernel._handle),
            *config.grid,
            *config.block,
            config.shmem_size,
            config.stream._handle,
            args_ptr, 0))
0 commit comments