2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
- import re
5
+
6
6
from functools import cached_property
7
7
8
8
import dpnp
19
19
from numba_dpex .core .datamodel .models import _init_data_model_manager
20
20
from numba_dpex .core .exceptions import UnsupportedKernelArgumentError
21
21
from numba_dpex .core .typeconv import to_usm_ndarray
22
- from numba_dpex .core .types import DpnpNdArray , USMNdArray
22
+ from numba_dpex .core .types import USMNdArray
23
23
from numba_dpex .core .utils import get_info_from_suai
24
24
from numba_dpex .utils import address_space , calling_conv
25
25
26
26
from .. import codegen
27
27
28
28
CC_SPIR_KERNEL = "spir_kernel"
29
29
CC_SPIR_FUNC = "spir_func"
30
- VALID_CHARS = re .compile (r"[^a-z0-9]" , re .I )
31
- LINK_ATOMIC = 111
32
30
LLVM_SPIRV_ARGS = 112
33
31
34
32
@@ -105,7 +103,7 @@ class SyclDevice(GPU):
105
103
pass
106
104
107
105
108
- DPEX_KERNEL_TARGET_NAME = "SyclDevice "
106
+ DPEX_KERNEL_TARGET_NAME = "dpex_kernel "
109
107
110
108
target_registry [DPEX_KERNEL_TARGET_NAME ] = SyclDevice
111
109
@@ -165,7 +163,7 @@ def _gen_arg_base_type(self, fn):
165
163
name = llvmir .MetaDataString (mod , "kernel_arg_base_type" )
166
164
return mod .add_metadata ([name ] + consts )
167
165
168
- def _finalize_wrapper_module (self , fn ):
166
+ def _finalize_kernel_wrapper_module (self , fn ):
169
167
"""Add metadata and calling convention to the wrapper function.
170
168
171
169
The helper function adds function metadata to the wrapper function and
@@ -177,41 +175,12 @@ def _finalize_wrapper_module(self, fn):
177
175
fn: LLVM function representing the "kernel" wrapper function.
178
176
179
177
"""
180
- mod = fn .module
181
178
# Set norecurse
182
179
fn .attributes .add ("norecurse" )
183
180
# Set SPIR kernel calling convention
184
181
fn .calling_convention = CC_SPIR_KERNEL
185
182
186
- # Mark kernels
187
- ocl_kernels = cgutils .get_or_insert_named_metadata (
188
- mod , "opencl.kernels"
189
- )
190
- ocl_kernels .add (
191
- mod .add_metadata (
192
- [
193
- fn ,
194
- self ._gen_arg_addrspace_md (fn ),
195
- self ._gen_arg_type (fn ),
196
- self ._gen_arg_type_qual (fn ),
197
- self ._gen_arg_base_type (fn ),
198
- ],
199
- )
200
- )
201
-
202
- # Other metadata
203
- others = [
204
- "opencl.used.extensions" ,
205
- "opencl.used.optional.core.features" ,
206
- "opencl.compiler.options" ,
207
- ]
208
-
209
- for name in others :
210
- nmd = cgutils .get_or_insert_named_metadata (mod , name )
211
- if not nmd .operands :
212
- mod .add_metadata ([])
213
-
214
- def _generate_kernel_wrapper (self , func , argtypes ):
183
+ def _generate_spir_kernel_wrapper (self , func , argtypes ):
215
184
module = func .module
216
185
arginfo = self .get_arg_packer (argtypes )
217
186
wrapperfnty = llvmir .FunctionType (
@@ -227,7 +196,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
227
196
func = llvmir .Function (wrapper_module , fnty , name = func .name )
228
197
func .calling_convention = CC_SPIR_FUNC
229
198
wrapper = llvmir .Function (wrapper_module , wrapperfnty , name = wrappername )
230
- builder = llvmir .IRBuilder (wrapper .append_basic_block ("" ))
199
+ builder = llvmir .IRBuilder (wrapper .append_basic_block ("entry " ))
231
200
232
201
callargs = arginfo .from_arguments (builder , wrapper .args )
233
202
@@ -237,7 +206,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
237
206
)
238
207
builder .ret_void ()
239
208
240
- self ._finalize_wrapper_module (wrapper )
209
+ self ._finalize_kernel_wrapper_module (wrapper )
241
210
242
211
# Link the spir_func module to the wrapper module
243
212
module .link_in (ll .parse_assembly (str (wrapper_module )))
@@ -251,7 +220,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
251
220
super ().__init__ (typingctx , target )
252
221
253
222
def init (self ):
254
- self ._internal_codegen = codegen .JITSPIRVCodegen ("numba_dpex.jit" )
223
+ """Called by the super().__init__ constructor to initalize the child
224
+ class.
225
+ """
226
+ self ._internal_codegen = codegen .JITSPIRVCodegen ("numba_dpex.kernel" )
255
227
self ._target_data = ll .create_target_data (
256
228
codegen .SPIR_DATA_LAYOUT [utils .MACHINE_BITS ]
257
229
)
@@ -271,7 +243,6 @@ def init(self):
271
243
self .ufunc_db = copy .deepcopy (ufunc_db )
272
244
self .cpu_context = cpu_target .target_context
273
245
274
- # Overrides
275
246
def create_module (self , name ):
276
247
return self ._internal_codegen ._create_empty_module (name )
277
248
@@ -355,14 +326,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
355
326
name + "dpex_fn" , argtypes , abi_tags = abi_tags , uid = uid
356
327
)
357
328
358
- def prepare_ocl_kernel (self , func , argtypes ):
329
+ def prepare_spir_kernel (self , func , argtypes ):
359
330
module = func .module
360
331
func .linkage = "linkonce_odr"
361
332
module .data_layout = codegen .SPIR_DATA_LAYOUT [self .address_size ]
362
- wrapper = self ._generate_kernel_wrapper (func , argtypes )
333
+ wrapper = self ._generate_spir_kernel_wrapper (func , argtypes )
363
334
return wrapper
364
335
365
- def mark_ocl_device (self , func ):
336
+ def set_spir_func_calling_conv (self , func ):
366
337
# Adapt to SPIR
367
338
func .calling_convention = CC_SPIR_FUNC
368
339
func .linkage = "linkonce_odr"
@@ -436,7 +407,6 @@ def addrspacecast(self, builder, src, addrspace):
436
407
ptras = llvmir .PointerType (src .type .pointee , addrspace = addrspace )
437
408
return builder .addrspacecast (src , ptras )
438
409
439
- # Overrides
440
410
def get_ufunc_info (self , ufunc_key ):
441
411
return self .ufunc_db [ufunc_key ]
442
412
@@ -446,7 +416,7 @@ class DpexCallConv(MinimalCallConv):
446
416
447
417
numba_dpex's calling convention derives from
448
418
:class:`numba.core.callconv import MinimalCallConv`. The
449
- :class:`DpexCallConv` overriddes :func:`call_function`.
419
+ :class:`DpexCallConv` overrides :func:`call_function`.
450
420
451
421
"""
452
422
0 commit comments