23
23
24
24
25
25
@dpex_exp .kernel
26
- def _kernel (nd_item : NdItem , a , slm ):
26
+ def _kernel1 (nd_item : NdItem , a , slm ):
27
27
i = nd_item .get_global_linear_id ()
28
- j = nd_item .get_local_linear_id ()
28
+
29
+ # TODO: overload nd_item.get_local_id()
30
+ j = (nd_item .get_local_id (0 ),)
31
+
32
+ slm [j ] = 0
33
+ group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
34
+
35
+ for m in range (100 ):
36
+ slm [j ] += i * m
37
+ group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
38
+
39
+ a [i ] = slm [j ]
40
+
41
+
42
+ @dpex_exp .kernel
43
+ def _kernel2 (nd_item : NdItem , a , slm ):
44
+ i = nd_item .get_global_linear_id ()
45
+
46
+ # TODO: overload nd_item.get_local_id()
47
+ j = (nd_item .get_local_id (0 ), nd_item .get_local_id (1 ))
48
+
49
+ slm [j ] = 0
50
+ group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
51
+
52
+ for m in range (100 ):
53
+ slm [j ] += i * m
54
+ group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
55
+
56
+ a [i ] = slm [j ]
57
+
58
+
59
+ @dpex_exp .kernel
60
+ def _kernel3 (nd_item : NdItem , a , slm ):
61
+ i = nd_item .get_global_linear_id ()
62
+
63
+ # TODO: overload nd_item.get_local_id()
64
+ j = (
65
+ nd_item .get_local_id (0 ),
66
+ nd_item .get_local_id (1 ),
67
+ nd_item .get_local_id (2 ),
68
+ )
29
69
30
70
slm [j ] = 0
31
71
group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
@@ -38,19 +78,27 @@ def _kernel(nd_item: NdItem, a, slm):
38
78
39
79
40
80
@pytest .mark .parametrize ("supported_dtype" , list_of_supported_dtypes )
41
- def test_local_accessor (supported_dtype ):
81
+ @pytest .mark .parametrize (
82
+ "nd_range, _kernel" ,
83
+ [
84
+ (dpex .NdRange ((32 ,), (32 ,)), _kernel1 ),
85
+ (dpex .NdRange ((32 , 1 ), (32 , 1 )), _kernel2 ),
86
+ (dpex .NdRange ((1 , 32 , 1 ), (1 , 32 , 1 )), _kernel3 ),
87
+ ],
88
+ )
89
+ def test_local_accessor (supported_dtype , nd_range : dpex .NdRange , _kernel ):
42
90
"""A test for passing a LocalAccessor object as a kernel argument."""
43
91
44
92
N = 32
45
93
a = dpnp .empty (N , dtype = supported_dtype )
46
- slm = LocalAccessor (( 32 * 64 ) , dtype = a .dtype )
94
+ slm = LocalAccessor (nd_range . local_range , dtype = a .dtype )
47
95
48
96
# A single work group with 32 work items is launched. Each work item
49
97
# computes the sum of (0..99) * its get_global_linear_id i.e.,
50
98
# `4950 * get_global_linear_id` and stores it into the work groups local
51
99
# memory. The local memory is of size 32*64 elements of the requested dtype.
52
100
# The result is then stored into `a` in global memory
53
- dpex_exp .call_kernel (_kernel , dpex . NdRange (( N ,), ( 32 ,)) , a , slm )
101
+ dpex_exp .call_kernel (_kernel , nd_range , a , slm )
54
102
55
103
for idx in range (N ):
56
104
assert a [idx ] == 4950 * idx
@@ -68,4 +116,4 @@ def test_local_accessor_argument_to_range_kernel():
68
116
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
69
117
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
70
118
with pytest .raises ((TypeError , TypingError )):
71
- dpex_exp .call_kernel (_kernel , dpex .Range (N ), a , slm )
119
+ dpex_exp .call_kernel (_kernel1 , dpex .Range (N ), a , slm )
0 commit comments