Skip to content

Commit 0437184

Browse files
author
Diptorup Deb
committed
Add a codegen unit test for local accessor kernel arg.
1 parent bb3234f commit 0437184

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpctl
6+
from llvmlite import ir as llvmir
7+
from numba.core import types
8+
9+
from numba_dpex import DpctlSyclQueue, DpnpNdArray
10+
from numba_dpex import experimental as dpex_exp
11+
from numba_dpex import int64
12+
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
13+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
14+
from numba_dpex.kernel_api import (
15+
AddressSpace,
16+
MemoryScope,
17+
NdItem,
18+
group_barrier,
19+
)
20+
21+
22+
def kernel_func(nd_item: NdItem, a, slm):
23+
i = nd_item.get_global_linear_id()
24+
j = nd_item.get_local_linear_id()
25+
26+
slm[j] = 100
27+
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
28+
29+
a[i] += slm[j]
30+
31+
32+
def test_codegen_local_accessor_kernel_arg():
33+
"""Tests if a kernel with a local accessor argument is generated with
34+
expected local address space pointer argument.
35+
"""
36+
37+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
38+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
39+
slm_ty = LocalAccessorType(ndim=1, dtype=int64)
40+
disp = dpex_exp.kernel(inline_threshold=3)(kernel_func)
41+
dmm = disp.targetctx.data_model_manager
42+
43+
i64arr_ty_flattened_arg_count = dmm.lookup(i64arr_ty).flattened_field_count
44+
slm_ty_model = dmm.lookup(slm_ty)
45+
slm_ty_flattened_arg_count = slm_ty_model.flattened_field_count
46+
slm_ptr_pos = slm_ty_model.get_field_position("data")
47+
48+
llargtys = disp.targetctx.get_arg_packer([i64arr_ty, slm_ty]).argument_types
49+
50+
# Go over all the arguments to the spir_kernel_func and assert two things:
51+
# a) Number of arguments == i64arr_ty_flattened_arg_count
52+
# + slm_ty_flattened_arg_count
53+
# b) The argument corresponding to the data attribute of the local accessor
54+
# argument is a pointer in address space local address space
55+
56+
num_kernel_args = 0
57+
slm_data_ptr_arg = None
58+
for kernel_arg in llargtys:
59+
if num_kernel_args == i64arr_ty_flattened_arg_count + slm_ptr_pos:
60+
slm_data_ptr_arg = kernel_arg
61+
num_kernel_args += 1
62+
assert (
63+
num_kernel_args
64+
== i64arr_ty_flattened_arg_count + slm_ty_flattened_arg_count
65+
)
66+
assert isinstance(slm_data_ptr_arg, llvmir.PointerType)
67+
assert slm_data_ptr_arg.addrspace == AddressSpace.LOCAL

0 commit comments

Comments
 (0)