Skip to content

Cleanups to kernel target #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from functools import cached_property

from numba.core import typing
from numba.core import options, targetconfig, typing
from numba.core.cpu import CPUTargetOptions
from numba.core.descriptors import TargetDescriptor

Expand All @@ -15,13 +15,42 @@
DpexKernelTypingContext,
)

_option_mapping = options._mapping


def _inherit_if_not_set(flags, options, name, default=targetconfig._NotSet):
if name in options:
setattr(flags, name, options[name])
return

cstk = targetconfig.ConfigStack()
if cstk:
# inherit
top = cstk.top()
if hasattr(top, name):
setattr(flags, name, getattr(top, name))
return

if default is not targetconfig._NotSet:
setattr(flags, name, default)


class DpexTargetOptions(CPUTargetOptions):
experimental = _option_mapping("experimental")
release_gil = _option_mapping("release_gil")

def finalize(self, flags, options):
super().finalize(flags, options)
_inherit_if_not_set(flags, options, "experimental", False)
_inherit_if_not_set(flags, options, "release_gil", False)


class DpexKernelTarget(TargetDescriptor):
"""
Implements a target descriptor for numba_dpex.kernel decorated functions.
"""

options = CPUTargetOptions
options = DpexTargetOptions

@cached_property
def _toplevel_target_context(self):
Expand Down
6 changes: 2 additions & 4 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

"""_summary_
"""

from numba.core import sigutils, types
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
Expand Down Expand Up @@ -67,7 +65,7 @@ def compile(self, arg_types, return_types):
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
cres.target_context.set_spir_func_calling_conv(func)

return cres

Expand Down Expand Up @@ -159,7 +157,7 @@ def compile(self, args):
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
cres.target_context.set_spir_func_calling_conv(func)
libs = [cres.library]

cres.target_context.insert_user_function(self, cres.fndesc, libs)
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def compile(
)

func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel = cres.target_context.prepare_ocl_kernel(
kernel = cres.target_context.prepare_spir_kernel(
func, cres.signature.args
)
cres.library._optimize_final_module()
Expand Down
63 changes: 17 additions & 46 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import re

from functools import cached_property

import dpnp
Expand All @@ -19,16 +19,14 @@
from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
from numba_dpex.core.typeconv import to_usm_ndarray
from numba_dpex.core.types import DpnpNdArray, USMNdArray
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import get_info_from_suai
from numba_dpex.utils import address_space, calling_conv

from .. import codegen

CC_SPIR_KERNEL = "spir_kernel"
CC_SPIR_FUNC = "spir_func"
VALID_CHARS = re.compile(r"[^a-z0-9]", re.I)
LINK_ATOMIC = 111
LLVM_SPIRV_ARGS = 112


Expand Down Expand Up @@ -89,14 +87,15 @@ def resolve_argument_type(self, val):

def load_additional_registries(self):
"""Register the OpenCL API and math and other functions."""
from numba.core.typing import cmathdecl, npydecl
from numba.core.typing import cmathdecl, enumdecl, npydecl

from ...ocl import mathdecl, ocldecl

self.install_registry(ocldecl.registry)
self.install_registry(mathdecl.registry)
self.install_registry(cmathdecl.registry)
self.install_registry(npydecl.registry)
self.install_registry(enumdecl.registry)


class SyclDevice(GPU):
Expand All @@ -105,7 +104,7 @@ class SyclDevice(GPU):
pass


DPEX_KERNEL_TARGET_NAME = "SyclDevice"
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"

target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice

Expand Down Expand Up @@ -165,7 +164,7 @@ def _gen_arg_base_type(self, fn):
name = llvmir.MetaDataString(mod, "kernel_arg_base_type")
return mod.add_metadata([name] + consts)

def _finalize_wrapper_module(self, fn):
def _finalize_kernel_wrapper_module(self, fn):
"""Add metadata and calling convention to the wrapper function.

The helper function adds function metadata to the wrapper function and
Expand All @@ -177,41 +176,12 @@ def _finalize_wrapper_module(self, fn):
fn: LLVM function representing the "kernel" wrapper function.

"""
mod = fn.module
# Set norecurse
fn.attributes.add("norecurse")
# Set SPIR kernel calling convention
fn.calling_convention = CC_SPIR_KERNEL

# Mark kernels
ocl_kernels = cgutils.get_or_insert_named_metadata(
mod, "opencl.kernels"
)
ocl_kernels.add(
mod.add_metadata(
[
fn,
self._gen_arg_addrspace_md(fn),
self._gen_arg_type(fn),
self._gen_arg_type_qual(fn),
self._gen_arg_base_type(fn),
],
)
)

# Other metadata
others = [
"opencl.used.extensions",
"opencl.used.optional.core.features",
"opencl.compiler.options",
]

for name in others:
nmd = cgutils.get_or_insert_named_metadata(mod, name)
if not nmd.operands:
mod.add_metadata([])

def _generate_kernel_wrapper(self, func, argtypes):
def _generate_spir_kernel_wrapper(self, func, argtypes):
module = func.module
arginfo = self.get_arg_packer(argtypes)
wrapperfnty = llvmir.FunctionType(
Expand All @@ -227,7 +197,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
func = llvmir.Function(wrapper_module, fnty, name=func.name)
func.calling_convention = CC_SPIR_FUNC
wrapper = llvmir.Function(wrapper_module, wrapperfnty, name=wrappername)
builder = llvmir.IRBuilder(wrapper.append_basic_block(""))
builder = llvmir.IRBuilder(wrapper.append_basic_block("entry"))

callargs = arginfo.from_arguments(builder, wrapper.args)

Expand All @@ -237,7 +207,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
)
builder.ret_void()

self._finalize_wrapper_module(wrapper)
self._finalize_kernel_wrapper_module(wrapper)

# Link the spir_func module to the wrapper module
module.link_in(ll.parse_assembly(str(wrapper_module)))
Expand All @@ -251,7 +221,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
super().__init__(typingctx, target)

def init(self):
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.jit")
"""Called by the super().__init__ constructor to initalize the child
class.
"""
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.kernel")
self._target_data = ll.create_target_data(
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
)
Expand All @@ -271,7 +244,6 @@ def init(self):
self.ufunc_db = copy.deepcopy(ufunc_db)
self.cpu_context = cpu_target.target_context

# Overrides
def create_module(self, name):
return self._internal_codegen._create_empty_module(name)

Expand Down Expand Up @@ -355,14 +327,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
)

def prepare_ocl_kernel(self, func, argtypes):
def prepare_spir_kernel(self, func, argtypes):
module = func.module
func.linkage = "linkonce_odr"
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
wrapper = self._generate_kernel_wrapper(func, argtypes)
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
return wrapper

def mark_ocl_device(self, func):
def set_spir_func_calling_conv(self, func):
# Adapt to SPIR
func.calling_convention = CC_SPIR_FUNC
func.linkage = "linkonce_odr"
Expand Down Expand Up @@ -436,7 +408,6 @@ def addrspacecast(self, builder, src, addrspace):
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
return builder.addrspacecast(src, ptras)

# Overrides
def get_ufunc_info(self, ufunc_key):
return self.ufunc_db[ufunc_key]

Expand All @@ -446,7 +417,7 @@ class DpexCallConv(MinimalCallConv):

numba_dpex's calling convention derives from
:class:`numba.core.callconv import MinimalCallConv`. The
:class:`DpexCallConv` overriddes :func:`call_function`.
:class:`DpexCallConv` overrides :func:`call_function`.

"""

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from subprocess import CalledProcessError, check_call

from numba_dpex import config
from numba_dpex.core.targets.kernel_target import LINK_ATOMIC, LLVM_SPIRV_ARGS
from numba_dpex.core.targets.kernel_target import LLVM_SPIRV_ARGS


def _raise_bad_env_path(msg, path, extra=None):
Expand Down