Skip to content

Commit ae0a1b9

Browse files
author
Diptorup Deb
committed
Disallow LocalAccessor arguments to RangeType kernels
1 parent eb8d7ac commit ae0a1b9

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

numba_dpex/experimental/launcher.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ItemType,
2626
NdItemType,
2727
)
28+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
2829
from numba_dpex.core.utils import kernel_launcher as kl
2930
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
3031
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference
@@ -42,6 +43,23 @@ class _LLRange(NamedTuple):
4243
local_range_extents: list
4344

4445

46+
def _has_a_local_accessor_argument(args):
47+
"""Checks if there exists at least one LocalAccessorType object in the
48+
input tuple.
49+
50+
Args:
51+
args (_type_): A tuple of numba.core.Type objects
52+
53+
Returns:
54+
bool : True if at least one LocalAccessorType object was found,
55+
otherwise False.
56+
"""
57+
for arg in args:
58+
if isinstance(arg, LocalAccessorType):
59+
return True
60+
return False
61+
62+
4563
def _wrap_event_reference_tuple(ctx, builder, event1, event2):
4664
"""Creates tuple data model from two event data models, so it can be
4765
boxed to Python."""
@@ -153,6 +171,18 @@ def _submit_kernel( # pylint: disable=too-many-arguments
153171
DeprecationWarning,
154172
)
155173

174+
# Validate local accessor arguments are passed only to a kernel that is
175+
# launched with an NdRange index space. Reference section 4.7.6.11. of the
176+
# SYCL 2020 specification: A local_accessor must not be used in a SYCL
177+
# kernel function that is invoked via single_task or via the simple form of
178+
# parallel_for that takes a range parameter.
179+
if _has_a_local_accessor_argument(ty_kernel_args_tuple) and isinstance(
180+
ty_index_space, RangeType
181+
):
182+
raise TypeError(
183+
"A RangeType kernel cannot have a LocalAccessor argument"
184+
)
185+
156186
# ty_kernel_fn is type specific to exact function, so we can get function
157187
# directly from type and compile it. Thats why we don't need to get it in
158188
# codegen

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import dpnp
77
import pytest
8+
from numba.core.errors import TypingError
89

910
import numba_dpex as dpex
1011
import numba_dpex.experimental as dpex_exp
@@ -21,23 +22,24 @@
2122
)
2223

2324

24-
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
25-
def test_local_accessor(supported_dtype):
26-
"""A test for passing a LocalAccessor object as a kernel argument."""
25+
@dpex_exp.kernel
26+
def _kernel(nd_item: NdItem, a, slm):
27+
i = nd_item.get_global_linear_id()
28+
j = nd_item.get_local_linear_id()
2729

28-
@dpex_exp.kernel
29-
def _kernel(nd_item: NdItem, a, slm):
30-
i = nd_item.get_global_linear_id()
31-
j = nd_item.get_local_linear_id()
30+
slm[j] = 0
31+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3232

33-
slm[j] = 0
33+
for m in range(100):
34+
slm[j] += i * m
3435
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3536

36-
for m in range(100):
37-
slm[j] += i * m
38-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
37+
a[i] = slm[j]
3938

40-
a[i] = slm[j]
39+
40+
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
41+
def test_local_accessor(supported_dtype):
42+
"""A test for passing a LocalAccessor object as a kernel argument."""
4143

4244
N = 32
4345
a = dpnp.empty(N, dtype=supported_dtype)
@@ -52,3 +54,18 @@ def _kernel(nd_item: NdItem, a, slm):
5254

5355
for idx in range(N):
5456
assert a[idx] == 4950 * idx
57+
58+
59+
def test_local_accessor_argument_to_range_kernel():
60+
"""Checks if an exception is raised when passing a local accessor to a
61+
RangeType kernel.
62+
"""
63+
N = 32
64+
a = dpnp.empty(N)
65+
slm = LocalAccessor((32 * 64), dtype=a.dtype)
66+
67+
# Passing a local_accessor to a RangeType kernel should raise an exception.
68+
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
69+
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
70+
with pytest.raises((TypeError, TypingError)):
71+
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)

0 commit comments

Comments
 (0)