32
32
#include " dpctl_sycl_device_manager.h"
33
33
#include " dpctl_sycl_type_casters.hpp"
34
34
#include < exception>
35
+ #include < iostream>
35
36
#include < stdexcept>
36
37
#include < sycl/sycl.hpp> /* SYCL headers */
37
38
#include < utility>
38
39
39
40
using namespace sycl ;
40
41
42
+ #define SET_LOCAL_ACCESSOR_ARG (CGH, NDIM, ARGTY, R, IDX ) \
43
+ do { \
44
+ switch ((ARGTY)) { \
45
+ case DPCTL_LONG_LONG: \
46
+ { \
47
+ auto la = local_accessor<long long , NDIM>(R, CGH); \
48
+ CGH.set_arg (IDX, la); \
49
+ return true ; \
50
+ } \
51
+ case DPCTL_UNSIGNED_LONG_LONG: \
52
+ { \
53
+ auto la = local_accessor<unsigned long long , NDIM>(R, CGH); \
54
+ CGH.set_arg (IDX, la); \
55
+ return true ; \
56
+ } \
57
+ case DPCTL_SIZE_T: \
58
+ { \
59
+ auto la = local_accessor<size_t , NDIM>(R, CGH); \
60
+ CGH.set_arg (IDX, la); \
61
+ return true ; \
62
+ } \
63
+ case DPCTL_FLOAT: \
64
+ { \
65
+ auto la = local_accessor<float , NDIM>(R, CGH); \
66
+ CGH.set_arg (IDX, la); \
67
+ return true ; \
68
+ } \
69
+ case DPCTL_DOUBLE: \
70
+ { \
71
+ auto la = local_accessor<double , NDIM>(R, CGH); \
72
+ CGH.set_arg (IDX, la); \
73
+ return true ; \
74
+ } \
75
+ default : \
76
+ error_handler (" Kernel argument could not be created." , __FILE__, \
77
+ __func__, __LINE__); \
78
+ return false ; \
79
+ } \
80
+ } while (0 );
81
+
41
82
namespace
42
83
{
43
84
static_assert (__SYCL_COMPILER_VERSION >= __SYCL_COMPILER_VERSION_REQUIRED,
@@ -51,11 +92,48 @@ typedef struct complex
51
92
uint64_t imag;
52
93
} complexNumber;
53
94
95
+ typedef struct MDLocalAccessorTy
96
+ {
97
+ size_t ndim;
98
+ DPCTLKernelArgType dpctl_type_id;
99
+ size_t dim0;
100
+ size_t dim1;
101
+ size_t dim2;
102
+ } MDLocalAccessor;
103
+
104
+ bool set_local_accessor_arg (handler &cgh,
105
+ size_t idx,
106
+ const MDLocalAccessor *mdstruct)
107
+ {
108
+ switch (mdstruct->ndim ) {
109
+ case 1 :
110
+ {
111
+ auto r = range<1 >(mdstruct->dim0 );
112
+ SET_LOCAL_ACCESSOR_ARG (cgh, 1 , mdstruct->dpctl_type_id , r, idx)
113
+ }
114
+ case 2 :
115
+ {
116
+ auto r = range<2 >(mdstruct->dim0 , mdstruct->dim1 );
117
+ SET_LOCAL_ACCESSOR_ARG (cgh, 2 , mdstruct->dpctl_type_id , r, idx)
118
+ }
119
+ case 3 :
120
+ {
121
+ auto r = range<3 >(mdstruct->dim0 , mdstruct->dim1 , mdstruct->dim2 );
122
+ SET_LOCAL_ACCESSOR_ARG (cgh, 3 , mdstruct->dpctl_type_id , r, idx)
123
+ }
124
+ default :
125
+ return false ;
126
+ }
127
+ }
54
128
/* !
55
129
* @brief Set the kernel arg object
56
130
*
57
- * @param cgh My Param doc
58
- * @param Arg My Param doc
131
+ * @param cgh SYCL command group handler using which a kernel is going to
132
+ * be submitted.
133
+ * @param idx The position of the argument in the list of arguments passed
134
+ * to a kernel.
135
+ * @param Arg A void* representing a kernel argument.
136
+ * @param Argty A typeid specifying the C++ type of the Arg parameter.
59
137
*/
60
138
bool set_kernel_arg (handler &cgh,
61
139
size_t idx,
@@ -113,6 +191,9 @@ bool set_kernel_arg(handler &cgh,
113
191
case DPCTL_VOID_PTR:
114
192
cgh.set_arg (idx, Arg);
115
193
break ;
194
+ case DPCTL_LOCAL_ACCESSOR:
195
+ arg_set = set_local_accessor_arg (cgh, idx, (MDLocalAccessor *)Arg);
196
+ break ;
116
197
default :
117
198
arg_set = false ;
118
199
error_handler (" Kernel argument could not be created." , __FILE__,
@@ -363,9 +444,9 @@ DPCTLQueue_SubmitRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
363
444
cgh.depends_on (*unwrap<event>(DepEvents[i]));
364
445
365
446
for (auto i = 0ul ; i < NArgs; ++i) {
366
- // \todo add support for Sycl buffers
367
- if (! set_kernel_arg (cgh, i, Args[i], ArgTypes[i]))
368
- exit ( 1 );
447
+ if (! set_kernel_arg (cgh, i, Args[i], ArgTypes[i])) {
448
+ return nullptr ;
449
+ }
369
450
}
370
451
switch (NDims) {
371
452
case 1 :
@@ -418,9 +499,9 @@ DPCTLQueue_SubmitNDRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
418
499
}
419
500
420
501
for (auto i = 0ul ; i < NArgs; ++i) {
421
- // \todo add support for Sycl buffers
422
- if (! set_kernel_arg (cgh, i, Args[i], ArgTypes[i]))
423
- exit ( 1 );
502
+ if (! set_kernel_arg (cgh, i, Args[i], ArgTypes[i])) {
503
+ return nullptr ;
504
+ }
424
505
}
425
506
switch (NDims) {
426
507
case 1 :
0 commit comments