Skip to content

Commit 86e1be7

Browse files
author
Diptorup Deb
authored
Merge pull request #1168 from IntelPython/cleanups_to_kernel_target
Cleanups to kernel target
2 parents 1ece135 + 262ed52 commit 86e1be7

File tree

5 files changed

+52
-54
lines changed

5 files changed

+52
-54
lines changed

numba_dpex/core/descriptor.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from functools import cached_property
66

7-
from numba.core import typing
7+
from numba.core import options, targetconfig, typing
88
from numba.core.cpu import CPUTargetOptions
99
from numba.core.descriptors import TargetDescriptor
1010

@@ -15,13 +15,42 @@
1515
DpexKernelTypingContext,
1616
)
1717

18+
_option_mapping = options._mapping
19+
20+
21+
def _inherit_if_not_set(flags, options, name, default=targetconfig._NotSet):
22+
if name in options:
23+
setattr(flags, name, options[name])
24+
return
25+
26+
cstk = targetconfig.ConfigStack()
27+
if cstk:
28+
# inherit
29+
top = cstk.top()
30+
if hasattr(top, name):
31+
setattr(flags, name, getattr(top, name))
32+
return
33+
34+
if default is not targetconfig._NotSet:
35+
setattr(flags, name, default)
36+
37+
38+
class DpexTargetOptions(CPUTargetOptions):
39+
experimental = _option_mapping("experimental")
40+
release_gil = _option_mapping("release_gil")
41+
42+
def finalize(self, flags, options):
43+
super().finalize(flags, options)
44+
_inherit_if_not_set(flags, options, "experimental", False)
45+
_inherit_if_not_set(flags, options, "release_gil", False)
46+
1847

1948
class DpexKernelTarget(TargetDescriptor):
2049
"""
2150
Implements a target descriptor for numba_dpex.kernel decorated functions.
2251
"""
2352

24-
options = CPUTargetOptions
53+
options = DpexTargetOptions
2554

2655
@cached_property
2756
def _toplevel_target_context(self):

numba_dpex/core/kernel_interface/func.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""_summary_
6-
"""
75

86
from numba.core import sigutils, types
97
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
@@ -67,7 +65,7 @@ def compile(self, arg_types, return_types):
6765
debug=self._debug,
6866
)
6967
func = cres.library.get_function(cres.fndesc.llvm_func_name)
70-
cres.target_context.mark_ocl_device(func)
68+
cres.target_context.set_spir_func_calling_conv(func)
7169

7270
return cres
7371

@@ -159,7 +157,7 @@ def compile(self, args):
159157
debug=self._debug,
160158
)
161159
func = cres.library.get_function(cres.fndesc.llvm_func_name)
162-
cres.target_context.mark_ocl_device(func)
160+
cres.target_context.set_spir_func_calling_conv(func)
163161
libs = [cres.library]
164162

165163
cres.target_context.insert_user_function(self, cres.fndesc, libs)

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def compile(
133133
)
134134

135135
func = cres.library.get_function(cres.fndesc.llvm_func_name)
136-
kernel = cres.target_context.prepare_ocl_kernel(
136+
kernel = cres.target_context.prepare_spir_kernel(
137137
func, cres.signature.args
138138
)
139139
cres.library._optimize_final_module()

numba_dpex/core/targets/kernel_target.py

Lines changed: 17 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import re
5+
66
from functools import cached_property
77

88
import dpnp
@@ -19,16 +19,14 @@
1919
from numba_dpex.core.datamodel.models import _init_data_model_manager
2020
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
2121
from numba_dpex.core.typeconv import to_usm_ndarray
22-
from numba_dpex.core.types import DpnpNdArray, USMNdArray
22+
from numba_dpex.core.types import USMNdArray
2323
from numba_dpex.core.utils import get_info_from_suai
2424
from numba_dpex.utils import address_space, calling_conv
2525

2626
from .. import codegen
2727

2828
CC_SPIR_KERNEL = "spir_kernel"
2929
CC_SPIR_FUNC = "spir_func"
30-
VALID_CHARS = re.compile(r"[^a-z0-9]", re.I)
31-
LINK_ATOMIC = 111
3230
LLVM_SPIRV_ARGS = 112
3331

3432

@@ -89,14 +87,15 @@ def resolve_argument_type(self, val):
8987

9088
def load_additional_registries(self):
9189
"""Register the OpenCL API and math and other functions."""
92-
from numba.core.typing import cmathdecl, npydecl
90+
from numba.core.typing import cmathdecl, enumdecl, npydecl
9391

9492
from ...ocl import mathdecl, ocldecl
9593

9694
self.install_registry(ocldecl.registry)
9795
self.install_registry(mathdecl.registry)
9896
self.install_registry(cmathdecl.registry)
9997
self.install_registry(npydecl.registry)
98+
self.install_registry(enumdecl.registry)
10099

101100

102101
class SyclDevice(GPU):
@@ -105,7 +104,7 @@ class SyclDevice(GPU):
105104
pass
106105

107106

108-
DPEX_KERNEL_TARGET_NAME = "SyclDevice"
107+
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"
109108

110109
target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice
111110

@@ -165,7 +164,7 @@ def _gen_arg_base_type(self, fn):
165164
name = llvmir.MetaDataString(mod, "kernel_arg_base_type")
166165
return mod.add_metadata([name] + consts)
167166

168-
def _finalize_wrapper_module(self, fn):
167+
def _finalize_kernel_wrapper_module(self, fn):
169168
"""Add metadata and calling convention to the wrapper function.
170169
171170
The helper function adds function metadata to the wrapper function and
@@ -177,41 +176,12 @@ def _finalize_wrapper_module(self, fn):
177176
fn: LLVM function representing the "kernel" wrapper function.
178177
179178
"""
180-
mod = fn.module
181179
# Set norecurse
182180
fn.attributes.add("norecurse")
183181
# Set SPIR kernel calling convention
184182
fn.calling_convention = CC_SPIR_KERNEL
185183

186-
# Mark kernels
187-
ocl_kernels = cgutils.get_or_insert_named_metadata(
188-
mod, "opencl.kernels"
189-
)
190-
ocl_kernels.add(
191-
mod.add_metadata(
192-
[
193-
fn,
194-
self._gen_arg_addrspace_md(fn),
195-
self._gen_arg_type(fn),
196-
self._gen_arg_type_qual(fn),
197-
self._gen_arg_base_type(fn),
198-
],
199-
)
200-
)
201-
202-
# Other metadata
203-
others = [
204-
"opencl.used.extensions",
205-
"opencl.used.optional.core.features",
206-
"opencl.compiler.options",
207-
]
208-
209-
for name in others:
210-
nmd = cgutils.get_or_insert_named_metadata(mod, name)
211-
if not nmd.operands:
212-
mod.add_metadata([])
213-
214-
def _generate_kernel_wrapper(self, func, argtypes):
184+
def _generate_spir_kernel_wrapper(self, func, argtypes):
215185
module = func.module
216186
arginfo = self.get_arg_packer(argtypes)
217187
wrapperfnty = llvmir.FunctionType(
@@ -227,7 +197,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
227197
func = llvmir.Function(wrapper_module, fnty, name=func.name)
228198
func.calling_convention = CC_SPIR_FUNC
229199
wrapper = llvmir.Function(wrapper_module, wrapperfnty, name=wrappername)
230-
builder = llvmir.IRBuilder(wrapper.append_basic_block(""))
200+
builder = llvmir.IRBuilder(wrapper.append_basic_block("entry"))
231201

232202
callargs = arginfo.from_arguments(builder, wrapper.args)
233203

@@ -237,7 +207,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
237207
)
238208
builder.ret_void()
239209

240-
self._finalize_wrapper_module(wrapper)
210+
self._finalize_kernel_wrapper_module(wrapper)
241211

242212
# Link the spir_func module to the wrapper module
243213
module.link_in(ll.parse_assembly(str(wrapper_module)))
@@ -251,7 +221,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
251221
super().__init__(typingctx, target)
252222

253223
def init(self):
254-
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.jit")
224+
"""Called by the super().__init__ constructor to initalize the child
225+
class.
226+
"""
227+
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.kernel")
255228
self._target_data = ll.create_target_data(
256229
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
257230
)
@@ -271,7 +244,6 @@ def init(self):
271244
self.ufunc_db = copy.deepcopy(ufunc_db)
272245
self.cpu_context = cpu_target.target_context
273246

274-
# Overrides
275247
def create_module(self, name):
276248
return self._internal_codegen._create_empty_module(name)
277249

@@ -355,14 +327,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
355327
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
356328
)
357329

358-
def prepare_ocl_kernel(self, func, argtypes):
330+
def prepare_spir_kernel(self, func, argtypes):
359331
module = func.module
360332
func.linkage = "linkonce_odr"
361333
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
362-
wrapper = self._generate_kernel_wrapper(func, argtypes)
334+
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
363335
return wrapper
364336

365-
def mark_ocl_device(self, func):
337+
def set_spir_func_calling_conv(self, func):
366338
# Adapt to SPIR
367339
func.calling_convention = CC_SPIR_FUNC
368340
func.linkage = "linkonce_odr"
@@ -436,7 +408,6 @@ def addrspacecast(self, builder, src, addrspace):
436408
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
437409
return builder.addrspacecast(src, ptras)
438410

439-
# Overrides
440411
def get_ufunc_info(self, ufunc_key):
441412
return self.ufunc_db[ufunc_key]
442413

@@ -446,7 +417,7 @@ class DpexCallConv(MinimalCallConv):
446417
447418
numba_dpex's calling convention derives from
448419
:class:`numba.core.callconv import MinimalCallConv`. The
449-
:class:`DpexCallConv` overriddes :func:`call_function`.
420+
:class:`DpexCallConv` overrides :func:`call_function`.
450421
451422
"""
452423

numba_dpex/spirv_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from subprocess import CalledProcessError, check_call
1010

1111
from numba_dpex import config
12-
from numba_dpex.core.targets.kernel_target import LINK_ATOMIC, LLVM_SPIRV_ARGS
12+
from numba_dpex.core.targets.kernel_target import LLVM_SPIRV_ARGS
1313

1414

1515
def _raise_bad_env_path(msg, path, extra=None):

0 commit comments

Comments
 (0)