Skip to content

Commit 8cad614

Browse files
committed
Add local accessor multidimensional tests
1 parent c086ae6 commit 8cad614

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,49 @@
2323

2424

2525
@dpex_exp.kernel
26-
def _kernel(nd_item: NdItem, a, slm):
26+
def _kernel1(nd_item: NdItem, a, slm):
2727
i = nd_item.get_global_linear_id()
28-
j = nd_item.get_local_linear_id()
28+
29+
# TODO: overload nd_item.get_local_id()
30+
j = (nd_item.get_local_id(0),)
31+
32+
slm[j] = 0
33+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
34+
35+
for m in range(100):
36+
slm[j] += i * m
37+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
38+
39+
a[i] = slm[j]
40+
41+
42+
@dpex_exp.kernel
43+
def _kernel2(nd_item: NdItem, a, slm):
44+
i = nd_item.get_global_linear_id()
45+
46+
# TODO: overload nd_item.get_local_id()
47+
j = (nd_item.get_local_id(0), nd_item.get_local_id(1))
48+
49+
slm[j] = 0
50+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
51+
52+
for m in range(100):
53+
slm[j] += i * m
54+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
55+
56+
a[i] = slm[j]
57+
58+
59+
@dpex_exp.kernel
60+
def _kernel3(nd_item: NdItem, a, slm):
61+
i = nd_item.get_global_linear_id()
62+
63+
# TODO: overload nd_item.get_local_id()
64+
j = (
65+
nd_item.get_local_id(0),
66+
nd_item.get_local_id(1),
67+
nd_item.get_local_id(2),
68+
)
2969

3070
slm[j] = 0
3171
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
@@ -38,19 +78,27 @@ def _kernel(nd_item: NdItem, a, slm):
3878

3979

4080
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
41-
def test_local_accessor(supported_dtype):
81+
@pytest.mark.parametrize(
82+
"nd_range, _kernel",
83+
[
84+
(dpex.NdRange((32,), (32,)), _kernel1),
85+
(dpex.NdRange((32, 1), (32, 1)), _kernel2),
86+
(dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3),
87+
],
88+
)
89+
def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
4290
"""A test for passing a LocalAccessor object as a kernel argument."""
4391

4492
N = 32
4593
a = dpnp.empty(N, dtype=supported_dtype)
46-
slm = LocalAccessor((32 * 64), dtype=a.dtype)
94+
slm = LocalAccessor(nd_range.local_range, dtype=a.dtype)
4795

4896
# A single work group with 32 work items is launched. Each work item
4997
# computes the sum of (0..99) * its get_global_linear_id i.e.,
5098
# `4950 * get_global_linear_id` and stores it into the work groups local
5199
# memory. The local memory has the shape of the local range (32 elements) of the requested dtype.
52100
# The result is then stored into `a` in global memory
53-
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (32,)), a, slm)
101+
dpex_exp.call_kernel(_kernel, nd_range, a, slm)
54102

55103
for idx in range(N):
56104
assert a[idx] == 4950 * idx
@@ -68,4 +116,4 @@ def test_local_accessor_argument_to_range_kernel():
68116
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
69117
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
70118
with pytest.raises((TypeError, TypingError)):
71-
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)
119+
dpex_exp.call_kernel(_kernel1, dpex.Range(N), a, slm)

0 commit comments

Comments
 (0)