Skip to content

Commit 547382d

Browse files
author
Diptorup Deb
authored
Merge pull request #1112 from IntelPython/fix/kernel_func_name
Generate proper mangled name for kernel functions
2 parents 4d332e8 + d3d7ef6 commit 547382d

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

numba_dpex/core/codegen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _optimize_final_module(self):
3535
pmb.opt_level = config.OPT
3636

3737
pmb.disable_unit_at_a_time = False
38+
pmb.inlining_threshold = 2
3839
pmb.disable_unroll_loops = True
3940
pmb.loop_vectorize = False
4041
pmb.slp_vectorize = False

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def compile(
136136
kernel = cres.target_context.prepare_ocl_kernel(
137137
func, cres.signature.args
138138
)
139+
cres.library._optimize_final_module()
139140
self._llvm_module = kernel.module.__str__()
140141
self._module_name = kernel.name
141142

numba_dpex/core/targets/kernel_target.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from llvmlite import binding as ll
1010
from llvmlite import ir as llvmir
1111
from numba import typeof
12-
from numba.core import cgutils, types, typing, utils
12+
from numba.core import cgutils, funcdesc, types, typing, utils
1313
from numba.core.base import BaseContext
1414
from numba.core.callconv import MinimalCallConv
1515
from numba.core.registry import cpu_target
@@ -240,7 +240,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
240240
llvmir.VoidType(), arginfo.argument_types
241241
)
242242
wrapper_module = self.create_module("dpex.kernel.wrapper")
243-
wrappername = "dpexPy_{name}".format(name=func.name)
243+
wrappername = func.name.replace("dpex_fn", "dpex_kernel")
244244
argtys = list(arginfo.argument_types)
245245
fnty = llvmir.FunctionType(
246246
llvmir.IntType(32),
@@ -373,13 +373,9 @@ def target_data(self):
373373
return self._target_data
374374

375375
def mangler(self, name, argtypes, abi_tags=(), uid=None):
376-
def repl(m):
377-
ch = m.group(0)
378-
return "_%X_" % ord(ch)
379-
380-
qualified = name + "." + ".".join(str(a) for a in argtypes)
381-
mangled = VALID_CHARS.sub(repl, qualified)
382-
return "dpex_py_devfn_" + mangled
376+
return funcdesc.default_mangler(
377+
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
378+
)
383379

384380
def prepare_ocl_kernel(self, func, argtypes):
385381
module = func.module

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,17 @@ def as_array(self):
222222
@property
223223
def box_type(self):
224224
return dpctl.tensor.usm_ndarray
225+
226+
@property
def mangling_args(self):
    """Return the components used to build a mangled name for this type.

    Returns:
        tuple: ``(class_name, args)`` where ``class_name`` is the name of
        the instance's class and ``args`` is the list
        ``[dtype, ndim, layout, device_tag]``. ``device_tag`` is made of
        the first two colon-separated fields of ``self.device`` joined
        with an underscore.
    """
    # Only the first two colon-separated fields of the device string are
    # used; any further fields are dropped. Slicing (rather than indexing
    # [0] and [1]) keeps a device string without a ":" separator from
    # raising IndexError during name mangling.
    # NOTE(review): presumably the device string is a SYCL filter string
    # such as "backend:device_type:device_number" -- confirm with callers.
    device_fields = self.device.split(":")
    device_tag = "_".join(device_fields[:2])

    args = [
        self.dtype,
        self.ndim,
        self.layout,
        device_tag,
    ]
    return self.__class__.__name__, args

0 commit comments

Comments
 (0)