Skip to content

Commit f666abe

Browse files
author
Diptorup Deb
committed
Fix mangled name generation for kernel functions.
1 parent 4d332e8 commit f666abe

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

numba_dpex/core/targets/kernel_target.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from llvmlite import binding as ll
1010
from llvmlite import ir as llvmir
1111
from numba import typeof
12-
from numba.core import cgutils, types, typing, utils
12+
from numba.core import cgutils, funcdesc, types, typing, utils
1313
from numba.core.base import BaseContext
1414
from numba.core.callconv import MinimalCallConv
1515
from numba.core.registry import cpu_target
@@ -240,7 +240,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
240240
llvmir.VoidType(), arginfo.argument_types
241241
)
242242
wrapper_module = self.create_module("dpex.kernel.wrapper")
243-
wrappername = "dpexPy_{name}".format(name=func.name)
243+
wrappername = func.name.replace("dpex_fn", "dpex_kernel")
244244
argtys = list(arginfo.argument_types)
245245
fnty = llvmir.FunctionType(
246246
llvmir.IntType(32),
@@ -373,13 +373,9 @@ def target_data(self):
373373
return self._target_data
374374

375375
def mangler(self, name, argtypes, abi_tags=(), uid=None):
376-
def repl(m):
377-
ch = m.group(0)
378-
return "_%X_" % ord(ch)
379-
380-
qualified = name + "." + ".".join(str(a) for a in argtypes)
381-
mangled = VALID_CHARS.sub(repl, qualified)
382-
return "dpex_py_devfn_" + mangled
376+
return funcdesc.default_mangler(
377+
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
378+
)
383379

384380
def prepare_ocl_kernel(self, func, argtypes):
385381
module = func.module

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,17 @@ def as_array(self):
222222
@property
223223
def box_type(self):
224224
return dpctl.tensor.usm_ndarray
225+
226+
@property
227+
def mangling_args(self):
228+
"""Returns the class name and a list of parameters used to create a
229+
mangled name for a USMNdArray type.
230+
"""
231+
filter_str_splits = self.device.split(":")
232+
args = [
233+
self.dtype,
234+
self.ndim,
235+
self.layout,
236+
filter_str_splits[0] + "_" + filter_str_splits[1],
237+
]
238+
return self.__class__.__name__, args

0 commit comments

Comments
 (0)