Skip to content

Commit d9d33b3

Browse files
committed
Add local accessor device func and python simulator tests
1 parent fa9ebf5 commit d9d33b3

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,54 +9,43 @@
99

1010
import numba_dpex as dpex
1111
import numba_dpex.experimental as dpex_exp
12-
from numba_dpex.kernel_api import (
13-
LocalAccessor,
14-
MemoryScope,
15-
NdItem,
16-
group_barrier,
17-
)
12+
from numba_dpex.kernel_api import LocalAccessor, NdItem
13+
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
1814
from numba_dpex.tests._helper import get_all_dtypes
1915

2016
list_of_supported_dtypes = get_all_dtypes(
2117
no_bool=True, no_float16=True, no_none=True, no_complex=True
2218
)
2319

2420

25-
@dpex_exp.kernel
2621
def _kernel1(nd_item: NdItem, a, slm):
2722
i = nd_item.get_global_linear_id()
2823

2924
# TODO: overload nd_item.get_local_id()
3025
j = (nd_item.get_local_id(0),)
3126

3227
slm[*j] = 0
33-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3428

3529
for m in range(100):
3630
slm[*j] += i * m
37-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3831

3932
a[i] = slm[*j]
4033

4134

42-
@dpex_exp.kernel
4335
def _kernel2(nd_item: NdItem, a, slm):
4436
i = nd_item.get_global_linear_id()
4537

4638
# TODO: overload nd_item.get_local_id()
4739
j = (nd_item.get_local_id(0), nd_item.get_local_id(1))
4840

4941
slm[*j] = 0
50-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
5142

5243
for m in range(100):
5344
slm[*j] += i * m
54-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
5545

5646
a[i] = slm[*j]
5747

5848

59-
@dpex_exp.kernel
6049
def _kernel3(nd_item: NdItem, a, slm):
6150
i = nd_item.get_global_linear_id()
6251

@@ -68,15 +57,23 @@ def _kernel3(nd_item: NdItem, a, slm):
6857
)
6958

7059
slm[*j] = 0
71-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
7260

7361
for m in range(100):
7462
slm[*j] += i * m
75-
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
7663

7764
a[i] = slm[*j]
7865

7966

67+
def device_func_kernel(func):
68+
_df = dpex_exp.device_func(func)
69+
70+
@dpex_exp.kernel
71+
def _kernel(item, a, slm):
72+
_df(item, a, slm)
73+
74+
return _kernel
75+
76+
8077
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
8178
@pytest.mark.parametrize(
8279
"nd_range, _kernel",
@@ -86,7 +83,17 @@ def _kernel3(nd_item: NdItem, a, slm):
8683
(dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3),
8784
],
8885
)
89-
def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
86+
@pytest.mark.parametrize(
87+
"call_kernel, kernel",
88+
[
89+
(dpex_exp.call_kernel, dpex_exp.kernel),
90+
(dpex_exp.call_kernel, device_func_kernel),
91+
(kapi_call_kernel, lambda f: f),
92+
],
93+
)
94+
def test_local_accessor(
95+
supported_dtype, nd_range: dpex.NdRange, _kernel, call_kernel, kernel
96+
):
9097
"""A test for passing a LocalAccessor object as a kernel argument."""
9198

9299
N = 32
@@ -98,7 +105,7 @@ def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
98105
# `4950 * get_global_linear_id` and stores it into the work groups local
99106
# memory. The local memory is of size 32*64 elements of the requested dtype.
100107
# The result is then stored into `a` in global memory
101-
dpex_exp.call_kernel(_kernel, nd_range, a, slm)
108+
call_kernel(kernel(_kernel), nd_range, a, slm)
102109

103110
for idx in range(N):
104111
assert a[idx] == 4950 * idx

0 commit comments

Comments
 (0)