Skip to content

Commit fa9ebf5

Browse files
committed
Add local accessor multidimentional tests
1 parent 583eb9b commit fa9ebf5

File tree

1 file changed

+57
-9
lines changed

1 file changed

+57
-9
lines changed

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,82 @@
2323

2424

2525
@dpex_exp.kernel
26-
def _kernel(nd_item: NdItem, a, slm):
26+
def _kernel1(nd_item: NdItem, a, slm):
2727
i = nd_item.get_global_linear_id()
28-
j = nd_item.get_local_linear_id()
2928

30-
slm[j] = 0
29+
# TODO: overload nd_item.get_local_id()
30+
j = (nd_item.get_local_id(0),)
31+
32+
slm[*j] = 0
3133
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3234

3335
for m in range(100):
34-
slm[j] += i * m
36+
slm[*j] += i * m
3537
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
3638

37-
a[i] = slm[j]
39+
a[i] = slm[*j]
40+
41+
42+
@dpex_exp.kernel
43+
def _kernel2(nd_item: NdItem, a, slm):
44+
i = nd_item.get_global_linear_id()
45+
46+
# TODO: overload nd_item.get_local_id()
47+
j = (nd_item.get_local_id(0), nd_item.get_local_id(1))
48+
49+
slm[*j] = 0
50+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
51+
52+
for m in range(100):
53+
slm[*j] += i * m
54+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
55+
56+
a[i] = slm[*j]
57+
58+
59+
@dpex_exp.kernel
60+
def _kernel3(nd_item: NdItem, a, slm):
61+
i = nd_item.get_global_linear_id()
62+
63+
# TODO: overload nd_item.get_local_id()
64+
j = (
65+
nd_item.get_local_id(0),
66+
nd_item.get_local_id(1),
67+
nd_item.get_local_id(2),
68+
)
69+
70+
slm[*j] = 0
71+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
72+
73+
for m in range(100):
74+
slm[*j] += i * m
75+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
76+
77+
a[i] = slm[*j]
3878

3979

4080
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
41-
def test_local_accessor(supported_dtype):
81+
@pytest.mark.parametrize(
82+
"nd_range, _kernel",
83+
[
84+
(dpex.NdRange((32,), (32,)), _kernel1),
85+
(dpex.NdRange((32, 1), (32, 1)), _kernel2),
86+
(dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3),
87+
],
88+
)
89+
def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
4290
"""A test for passing a LocalAccessor object as a kernel argument."""
4391

4492
N = 32
4593
a = dpnp.empty(N, dtype=supported_dtype)
46-
slm = LocalAccessor((32 * 64), dtype=a.dtype)
94+
slm = LocalAccessor(nd_range.local_range, dtype=a.dtype)
4795

4896
# A single work group with 32 work items is launched. Each work item
4997
# computes the sum of (0..99) * its get_global_linear_id i.e.,
5098
# `4950 * get_global_linear_id` and stores it into the work groups local
5199
# memory. The local memory is of size 32*64 elements of the requested dtype.
52100
# The result is then stored into `a` in global memory
53-
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (32,)), a, slm)
101+
dpex_exp.call_kernel(_kernel, nd_range, a, slm)
54102

55103
for idx in range(N):
56104
assert a[idx] == 4950 * idx
@@ -68,4 +116,4 @@ def test_local_accessor_argument_to_range_kernel():
68116
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
69117
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
70118
with pytest.raises((TypeError, TypingError)):
71-
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)
119+
dpex_exp.call_kernel(_kernel1, dpex.Range(N), a, slm)

0 commit comments

Comments
 (0)