2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
- import re
5
+
6
6
from functools import cached_property
7
7
8
8
import dpnp
19
19
from numba_dpex .core .datamodel .models import _init_data_model_manager
20
20
from numba_dpex .core .exceptions import UnsupportedKernelArgumentError
21
21
from numba_dpex .core .typeconv import to_usm_ndarray
22
- from numba_dpex .core .types import DpnpNdArray , USMNdArray
22
+ from numba_dpex .core .types import USMNdArray
23
23
from numba_dpex .core .utils import get_info_from_suai
24
24
from numba_dpex .utils import address_space , calling_conv
25
25
26
26
from .. import codegen
27
27
28
28
CC_SPIR_KERNEL = "spir_kernel"
29
29
CC_SPIR_FUNC = "spir_func"
30
- VALID_CHARS = re .compile (r"[^a-z0-9]" , re .I )
31
- LINK_ATOMIC = 111
32
30
LLVM_SPIRV_ARGS = 112
33
31
34
32
@@ -89,14 +87,15 @@ def resolve_argument_type(self, val):
89
87
90
88
def load_additional_registries (self ):
91
89
"""Register the OpenCL API and math and other functions."""
92
- from numba .core .typing import cmathdecl , npydecl
90
+ from numba .core .typing import cmathdecl , enumdecl , npydecl
93
91
94
92
from ...ocl import mathdecl , ocldecl
95
93
96
94
self .install_registry (ocldecl .registry )
97
95
self .install_registry (mathdecl .registry )
98
96
self .install_registry (cmathdecl .registry )
99
97
self .install_registry (npydecl .registry )
98
+ self .install_registry (enumdecl .registry )
100
99
101
100
102
101
class SyclDevice (GPU ):
@@ -105,7 +104,7 @@ class SyclDevice(GPU):
105
104
pass
106
105
107
106
108
- DPEX_KERNEL_TARGET_NAME = "SyclDevice "
107
+ DPEX_KERNEL_TARGET_NAME = "dpex_kernel "
109
108
110
109
target_registry [DPEX_KERNEL_TARGET_NAME ] = SyclDevice
111
110
@@ -165,7 +164,7 @@ def _gen_arg_base_type(self, fn):
165
164
name = llvmir .MetaDataString (mod , "kernel_arg_base_type" )
166
165
return mod .add_metadata ([name ] + consts )
167
166
168
- def _finalize_wrapper_module (self , fn ):
167
+ def _finalize_kernel_wrapper_module (self , fn ):
169
168
"""Add metadata and calling convention to the wrapper function.
170
169
171
170
The helper function adds function metadata to the wrapper function and
@@ -177,41 +176,12 @@ def _finalize_wrapper_module(self, fn):
177
176
fn: LLVM function representing the "kernel" wrapper function.
178
177
179
178
"""
180
- mod = fn .module
181
179
# Set norecurse
182
180
fn .attributes .add ("norecurse" )
183
181
# Set SPIR kernel calling convention
184
182
fn .calling_convention = CC_SPIR_KERNEL
185
183
186
- # Mark kernels
187
- ocl_kernels = cgutils .get_or_insert_named_metadata (
188
- mod , "opencl.kernels"
189
- )
190
- ocl_kernels .add (
191
- mod .add_metadata (
192
- [
193
- fn ,
194
- self ._gen_arg_addrspace_md (fn ),
195
- self ._gen_arg_type (fn ),
196
- self ._gen_arg_type_qual (fn ),
197
- self ._gen_arg_base_type (fn ),
198
- ],
199
- )
200
- )
201
-
202
- # Other metadata
203
- others = [
204
- "opencl.used.extensions" ,
205
- "opencl.used.optional.core.features" ,
206
- "opencl.compiler.options" ,
207
- ]
208
-
209
- for name in others :
210
- nmd = cgutils .get_or_insert_named_metadata (mod , name )
211
- if not nmd .operands :
212
- mod .add_metadata ([])
213
-
214
- def _generate_kernel_wrapper (self , func , argtypes ):
184
+ def _generate_spir_kernel_wrapper (self , func , argtypes ):
215
185
module = func .module
216
186
arginfo = self .get_arg_packer (argtypes )
217
187
wrapperfnty = llvmir .FunctionType (
@@ -227,7 +197,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
227
197
func = llvmir .Function (wrapper_module , fnty , name = func .name )
228
198
func .calling_convention = CC_SPIR_FUNC
229
199
wrapper = llvmir .Function (wrapper_module , wrapperfnty , name = wrappername )
230
- builder = llvmir .IRBuilder (wrapper .append_basic_block ("" ))
200
+ builder = llvmir .IRBuilder (wrapper .append_basic_block ("entry " ))
231
201
232
202
callargs = arginfo .from_arguments (builder , wrapper .args )
233
203
@@ -237,7 +207,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
237
207
)
238
208
builder .ret_void ()
239
209
240
- self ._finalize_wrapper_module (wrapper )
210
+ self ._finalize_kernel_wrapper_module (wrapper )
241
211
242
212
# Link the spir_func module to the wrapper module
243
213
module .link_in (ll .parse_assembly (str (wrapper_module )))
@@ -251,7 +221,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
251
221
super ().__init__ (typingctx , target )
252
222
253
223
def init (self ):
254
- self ._internal_codegen = codegen .JITSPIRVCodegen ("numba_dpex.jit" )
224
+ """Called by the super().__init__ constructor to initialize the child
225
+ class.
226
+ """
227
+ self ._internal_codegen = codegen .JITSPIRVCodegen ("numba_dpex.kernel" )
255
228
self ._target_data = ll .create_target_data (
256
229
codegen .SPIR_DATA_LAYOUT [utils .MACHINE_BITS ]
257
230
)
@@ -271,7 +244,6 @@ def init(self):
271
244
self .ufunc_db = copy .deepcopy (ufunc_db )
272
245
self .cpu_context = cpu_target .target_context
273
246
274
- # Overrides
275
247
def create_module (self , name ):
276
248
return self ._internal_codegen ._create_empty_module (name )
277
249
@@ -355,14 +327,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
355
327
name + "dpex_fn" , argtypes , abi_tags = abi_tags , uid = uid
356
328
)
357
329
358
- def prepare_ocl_kernel (self , func , argtypes ):
330
+ def prepare_spir_kernel (self , func , argtypes ):
359
331
module = func .module
360
332
func .linkage = "linkonce_odr"
361
333
module .data_layout = codegen .SPIR_DATA_LAYOUT [self .address_size ]
362
- wrapper = self ._generate_kernel_wrapper (func , argtypes )
334
+ wrapper = self ._generate_spir_kernel_wrapper (func , argtypes )
363
335
return wrapper
364
336
365
- def mark_ocl_device (self , func ):
337
+ def set_spir_func_calling_conv (self , func ):
366
338
# Adapt to SPIR
367
339
func .calling_convention = CC_SPIR_FUNC
368
340
func .linkage = "linkonce_odr"
@@ -436,7 +408,6 @@ def addrspacecast(self, builder, src, addrspace):
436
408
ptras = llvmir .PointerType (src .type .pointee , addrspace = addrspace )
437
409
return builder .addrspacecast (src , ptras )
438
410
439
- # Overrides
440
411
def get_ufunc_info (self , ufunc_key ):
441
412
return self .ufunc_db [ufunc_key ]
442
413
@@ -446,7 +417,7 @@ class DpexCallConv(MinimalCallConv):
446
417
447
418
numba_dpex's calling convention derives from
448
419
:class:`numba.core.callconv.MinimalCallConv`. The
449
- :class:`DpexCallConv` overriddes :func:`call_function`.
420
+ :class:`DpexCallConv` overrides :func:`call_function`.
450
421
451
422
"""
452
423
0 commit comments