Skip to content

Commit aa787db

Browse files
author
Diptorup Deb
committed
Cleaning up kernel_target.
- Renames functions to better indicate usage. (e.g., s/ocl/spir/g) - Renamed DPEX_KERNEL_TARGET_NAME to "dpex_kernel" from "SyclDevice". - Removes unused globals: VALID_CHARS, LINK_ATOMIC. - Removes unused imports. - Minor typo clean up and removes superfluous comments.
1 parent 004ca5f commit aa787db

File tree

4 files changed

+19
-51
lines changed

4 files changed

+19
-51
lines changed

numba_dpex/core/kernel_interface/func.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""_summary_
6-
"""
75

86
from numba.core import sigutils, types
97
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
@@ -67,7 +65,7 @@ def compile(self, arg_types, return_types):
6765
debug=self._debug,
6866
)
6967
func = cres.library.get_function(cres.fndesc.llvm_func_name)
70-
cres.target_context.mark_ocl_device(func)
68+
cres.target_context.set_spir_func_calling_conv(func)
7169

7270
return cres
7371

@@ -159,7 +157,7 @@ def compile(self, args):
159157
debug=self._debug,
160158
)
161159
func = cres.library.get_function(cres.fndesc.llvm_func_name)
162-
cres.target_context.mark_ocl_device(func)
160+
cres.target_context.set_spir_func_calling_conv(func)
163161
libs = [cres.library]
164162

165163
cres.target_context.insert_user_function(self, cres.fndesc, libs)

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def compile(
133133
)
134134

135135
func = cres.library.get_function(cres.fndesc.llvm_func_name)
136-
kernel = cres.target_context.prepare_ocl_kernel(
136+
kernel = cres.target_context.prepare_spir_kernel(
137137
func, cres.signature.args
138138
)
139139
cres.library._optimize_final_module()

numba_dpex/core/targets/kernel_target.py

Lines changed: 15 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import re
5+
66
from functools import cached_property
77

88
import dpnp
@@ -19,16 +19,14 @@
1919
from numba_dpex.core.datamodel.models import _init_data_model_manager
2020
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
2121
from numba_dpex.core.typeconv import to_usm_ndarray
22-
from numba_dpex.core.types import DpnpNdArray, USMNdArray
22+
from numba_dpex.core.types import USMNdArray
2323
from numba_dpex.core.utils import get_info_from_suai
2424
from numba_dpex.utils import address_space, calling_conv
2525

2626
from .. import codegen
2727

2828
CC_SPIR_KERNEL = "spir_kernel"
2929
CC_SPIR_FUNC = "spir_func"
30-
VALID_CHARS = re.compile(r"[^a-z0-9]", re.I)
31-
LINK_ATOMIC = 111
3230
LLVM_SPIRV_ARGS = 112
3331

3432

@@ -105,7 +103,7 @@ class SyclDevice(GPU):
105103
pass
106104

107105

108-
DPEX_KERNEL_TARGET_NAME = "SyclDevice"
106+
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"
109107

110108
target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice
111109

@@ -165,7 +163,7 @@ def _gen_arg_base_type(self, fn):
165163
name = llvmir.MetaDataString(mod, "kernel_arg_base_type")
166164
return mod.add_metadata([name] + consts)
167165

168-
def _finalize_wrapper_module(self, fn):
166+
def _finalize_kernel_wrapper_module(self, fn):
169167
"""Add metadata and calling convention to the wrapper function.
170168
171169
The helper function adds function metadata to the wrapper function and
@@ -177,41 +175,12 @@ def _finalize_wrapper_module(self, fn):
177175
fn: LLVM function representing the "kernel" wrapper function.
178176
179177
"""
180-
mod = fn.module
181178
# Set norecurse
182179
fn.attributes.add("norecurse")
183180
# Set SPIR kernel calling convention
184181
fn.calling_convention = CC_SPIR_KERNEL
185182

186-
# Mark kernels
187-
ocl_kernels = cgutils.get_or_insert_named_metadata(
188-
mod, "opencl.kernels"
189-
)
190-
ocl_kernels.add(
191-
mod.add_metadata(
192-
[
193-
fn,
194-
self._gen_arg_addrspace_md(fn),
195-
self._gen_arg_type(fn),
196-
self._gen_arg_type_qual(fn),
197-
self._gen_arg_base_type(fn),
198-
],
199-
)
200-
)
201-
202-
# Other metadata
203-
others = [
204-
"opencl.used.extensions",
205-
"opencl.used.optional.core.features",
206-
"opencl.compiler.options",
207-
]
208-
209-
for name in others:
210-
nmd = cgutils.get_or_insert_named_metadata(mod, name)
211-
if not nmd.operands:
212-
mod.add_metadata([])
213-
214-
def _generate_kernel_wrapper(self, func, argtypes):
183+
def _generate_spir_kernel_wrapper(self, func, argtypes):
215184
module = func.module
216185
arginfo = self.get_arg_packer(argtypes)
217186
wrapperfnty = llvmir.FunctionType(
@@ -227,7 +196,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
227196
func = llvmir.Function(wrapper_module, fnty, name=func.name)
228197
func.calling_convention = CC_SPIR_FUNC
229198
wrapper = llvmir.Function(wrapper_module, wrapperfnty, name=wrappername)
230-
builder = llvmir.IRBuilder(wrapper.append_basic_block(""))
199+
builder = llvmir.IRBuilder(wrapper.append_basic_block("entry"))
231200

232201
callargs = arginfo.from_arguments(builder, wrapper.args)
233202

@@ -237,7 +206,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
237206
)
238207
builder.ret_void()
239208

240-
self._finalize_wrapper_module(wrapper)
209+
self._finalize_kernel_wrapper_module(wrapper)
241210

242211
# Link the spir_func module to the wrapper module
243212
module.link_in(ll.parse_assembly(str(wrapper_module)))
@@ -251,7 +220,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
251220
super().__init__(typingctx, target)
252221

253222
def init(self):
254-
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.jit")
223+
"""Called by the super().__init__ constructor to initalize the child
224+
class.
225+
"""
226+
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.kernel")
255227
self._target_data = ll.create_target_data(
256228
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
257229
)
@@ -271,7 +243,6 @@ def init(self):
271243
self.ufunc_db = copy.deepcopy(ufunc_db)
272244
self.cpu_context = cpu_target.target_context
273245

274-
# Overrides
275246
def create_module(self, name):
276247
return self._internal_codegen._create_empty_module(name)
277248

@@ -355,14 +326,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
355326
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
356327
)
357328

358-
def prepare_ocl_kernel(self, func, argtypes):
329+
def prepare_spir_kernel(self, func, argtypes):
359330
module = func.module
360331
func.linkage = "linkonce_odr"
361332
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
362-
wrapper = self._generate_kernel_wrapper(func, argtypes)
333+
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
363334
return wrapper
364335

365-
def mark_ocl_device(self, func):
336+
def set_spir_func_calling_conv(self, func):
366337
# Adapt to SPIR
367338
func.calling_convention = CC_SPIR_FUNC
368339
func.linkage = "linkonce_odr"
@@ -436,7 +407,6 @@ def addrspacecast(self, builder, src, addrspace):
436407
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
437408
return builder.addrspacecast(src, ptras)
438409

439-
# Overrides
440410
def get_ufunc_info(self, ufunc_key):
441411
return self.ufunc_db[ufunc_key]
442412

@@ -446,7 +416,7 @@ class DpexCallConv(MinimalCallConv):
446416
447417
numba_dpex's calling convention derives from
448418
:class:`numba.core.callconv import MinimalCallConv`. The
449-
:class:`DpexCallConv` overriddes :func:`call_function`.
419+
:class:`DpexCallConv` overrides :func:`call_function`.
450420
451421
"""
452422

numba_dpex/spirv_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from subprocess import CalledProcessError, check_call
1010

1111
from numba_dpex import config
12-
from numba_dpex.core.targets.kernel_target import LINK_ATOMIC, LLVM_SPIRV_ARGS
12+
from numba_dpex.core.targets.kernel_target import LLVM_SPIRV_ARGS
1313

1414

1515
def _raise_bad_env_path(msg, path, extra=None):

0 commit comments

Comments
 (0)