Skip to content

Commit bb3234f

Browse files
author
Diptorup Deb
committed
Add a unit test for local accessor
1 parent 0dd344e commit bb3234f

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
import dpnp
7+
import pytest
8+
9+
import numba_dpex as dpex
10+
import numba_dpex.experimental as dpex_exp
11+
from numba_dpex.kernel_api import (
12+
LocalAccessor,
13+
MemoryScope,
14+
NdItem,
15+
group_barrier,
16+
)
17+
from numba_dpex.tests._helper import get_all_dtypes
18+
19+
# Dtypes used to parametrize the test in this module. Judging by the flag
# names, the helper is asked to leave out bool, float16, None and complex
# entries -- confirm against `get_all_dtypes` in numba_dpex.tests._helper.
list_of_supported_dtypes = get_all_dtypes(
    no_bool=True,
    no_float16=True,
    no_none=True,
    no_complex=True,
)
22+
23+
24+
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
def test_local_accessor(supported_dtype):
    """Check that a LocalAccessor can be passed to a kernel as an argument.

    The kernel accumulates a per-work-item value in the memory exposed by the
    LocalAccessor and then copies it to a global array, which the host side
    compares against the closed-form expected result.
    """

    @dpex_exp.kernel
    def _accumulate_in_slm(nd_item: NdItem, out, scratch):
        gid = nd_item.get_global_linear_id()
        lid = nd_item.get_local_linear_id()

        # Every work item owns exactly one slot of the local buffer, indexed
        # by its local linear id; start from a known value.
        scratch[lid] = 0
        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

        # Accumulate gid * (0 + 1 + ... + 99) == 4950 * gid into that slot.
        for step in range(100):
            scratch[lid] += gid * step
            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

        # Publish the accumulated value to the output array.
        out[gid] = scratch[lid]

    global_size = 32
    out = dpnp.empty(global_size, dtype=supported_dtype)
    # The accessor requests 32 * 64 elements of the dtype under test, which
    # is more than the 32 slots the kernel actually touches.
    scratch = LocalAccessor((32 * 64), dtype=out.dtype)

    # Global and local sizes are both 32, so the launch consists of a single
    # work group of 32 work items and the local linear id of each work item
    # matches its global linear id.
    launch_range = dpex.NdRange((global_size,), (32,))
    dpex_exp.call_kernel(_accumulate_in_slm, launch_range, out, scratch)

    # sum(range(100)) == 4950, hence element `idx` must hold 4950 * idx.
    expected_factor = 4950
    for idx in range(global_size):
        assert out[idx] == expected_factor * idx

0 commit comments

Comments
 (0)