9
9
10
10
import numba_dpex as dpex
11
11
import numba_dpex .experimental as dpex_exp
12
- from numba_dpex .kernel_api import (
13
- LocalAccessor ,
14
- MemoryScope ,
15
- NdItem ,
16
- group_barrier ,
17
- )
12
+ from numba_dpex .kernel_api import LocalAccessor , NdItem
13
+ from numba_dpex .kernel_api import call_kernel as kapi_call_kernel
18
14
from numba_dpex .tests ._helper import get_all_dtypes
19
15
20
16
# Dtypes the local accessor test is parametrized over. The flags ask the
# helper to leave out the bool, float16, None and complex entries.
list_of_supported_dtypes = get_all_dtypes(
    no_bool=True, no_float16=True, no_none=True, no_complex=True
)
23
19
24
20
25
def _kernel1(nd_item: NdItem, a, slm):
    """Accumulate ``i * m`` for ``m`` in ``range(100)`` through ``slm``.

    The function is intentionally left undecorated: the test wraps it with
    the kernel decorator under test before launching it.

    Args:
        nd_item: Work-item handle; provides the global linear id and the
            local id along dimension 0.
        a: Output buffer indexed by the global linear id.
        slm: Scratch buffer indexed by a 1-tuple of the local id
            (presumably a LocalAccessor -- see the test that launches it).
    """
    i = nd_item.get_global_linear_id()

    # TODO: overload nd_item.get_local_id()
    j = (nd_item.get_local_id(0),)

    slm[j] = 0

    # sum(range(100)) == 4950, so the slot ends up holding 4950 * i.
    for m in range(100):
        slm[j] += i * m

    a[i] = slm[j]
40
33
41
34
42
def _kernel2(nd_item: NdItem, a, slm):
    """Accumulate ``i * m`` for ``m`` in ``range(100)`` through ``slm``.

    Same computation as the one-dimensional variant, but the scratch buffer
    is indexed with a 2-tuple of local ids. The function is intentionally
    left undecorated: the test wraps it with the kernel decorator under test
    before launching it.

    Args:
        nd_item: Work-item handle; provides the global linear id and the
            local ids along dimensions 0 and 1.
        a: Output buffer indexed by the global linear id.
        slm: Scratch buffer indexed by ``(local_id(0), local_id(1))``
            (presumably a LocalAccessor -- see the test that launches it).
    """
    i = nd_item.get_global_linear_id()

    # TODO: overload nd_item.get_local_id()
    j = (nd_item.get_local_id(0), nd_item.get_local_id(1))

    slm[j] = 0

    # sum(range(100)) == 4950, so the slot ends up holding 4950 * i.
    for m in range(100):
        slm[j] += i * m

    a[i] = slm[j]
57
47
58
48
59
- @dpex_exp .kernel
60
49
def _kernel3 (nd_item : NdItem , a , slm ):
61
50
i = nd_item .get_global_linear_id ()
62
51
@@ -68,15 +57,23 @@ def _kernel3(nd_item: NdItem, a, slm):
68
57
)
69
58
70
59
slm [j ] = 0
71
- group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
72
60
73
61
for m in range (100 ):
74
62
slm [j ] += i * m
75
- group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
76
63
77
64
a [i ] = slm [j ]
78
65
79
66
67
def device_func_kernel(func):
    """Return a kernel that forwards its arguments to ``func``.

    ``func`` is first wrapped with ``dpex_exp.device_func``; the returned
    ``dpex_exp.kernel`` then calls that device function with the same
    ``(item, a, slm)`` arguments, so the kernel body is exercised through a
    device-function call instead of being compiled as the kernel itself.

    Args:
        func: Plain Python function taking ``(item, a, slm)``.

    Returns:
        The ``dpex_exp.kernel``-decorated wrapper.
    """
    _df = dpex_exp.device_func(func)

    @dpex_exp.kernel
    def _kernel(item, a, slm):
        _df(item, a, slm)

    return _kernel
75
+
76
+
80
77
@pytest .mark .parametrize ("supported_dtype" , list_of_supported_dtypes )
81
78
@pytest .mark .parametrize (
82
79
"nd_range, _kernel" ,
@@ -86,7 +83,17 @@ def _kernel3(nd_item: NdItem, a, slm):
86
83
(dpex .NdRange ((1 , 32 , 1 ), (1 , 32 , 1 )), _kernel3 ),
87
84
],
88
85
)
89
- def test_local_accessor (supported_dtype , nd_range : dpex .NdRange , _kernel ):
86
+ @pytest .mark .parametrize (
87
+ "call_kernel, kernel" ,
88
+ [
89
+ (dpex_exp .call_kernel , dpex_exp .kernel ),
90
+ (dpex_exp .call_kernel , device_func_kernel ),
91
+ (kapi_call_kernel , lambda f : f ),
92
+ ],
93
+ )
94
+ def test_local_accessor (
95
+ supported_dtype , nd_range : dpex .NdRange , _kernel , call_kernel , kernel
96
+ ):
90
97
"""A test for passing a LocalAccessor object as a kernel argument."""
91
98
92
99
N = 32
@@ -98,7 +105,7 @@ def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
98
105
# `4950 * get_global_linear_id` and stores it into the work group's local
99
106
# memory. The local memory is of size 32*64 elements of the requested dtype.
100
107
# The result is then stored into `a` in global memory
101
- dpex_exp . call_kernel (_kernel , nd_range , a , slm )
108
+ call_kernel (kernel ( _kernel ) , nd_range , a , slm )
102
109
103
110
for idx in range (N ):
104
111
assert a [idx ] == 4950 * idx
0 commit comments