Skip to content

Commit dca1db1

Browse files
author
Diptorup Deb
committed
Initial changes to support local_accessor arguments to kernels.
1 parent 7ab3731 commit dca1db1

File tree

4 files changed

+118
-10
lines changed

4 files changed

+118
-10
lines changed

dpctl/enum_types.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"device_type",
2727
"backend_type",
2828
"event_status_type",
29+
"kernel_arg_type",
2930
]
3031

3132

@@ -113,3 +114,28 @@ class global_mem_cache_type(Enum):
113114
none = auto()
114115
read_only = auto()
115116
read_write = auto()
117+
118+
119+
class kernel_arg_type(Enum):
120+
"""
121+
An enumeration of supported kernel argument types in
122+
:func:`dpctl.SyclQueue.submit`
123+
"""
124+
125+
dpctl_char = auto()
126+
dpctl_signed_char = auto()
127+
dpctl_unsigned_char = auto()
128+
dpctl_short = auto()
129+
dpctl_int = auto()
130+
dpctl_unsigned_int = auto()
131+
dpctl_unsigned_int8 = auto()
132+
dpctl_long = auto()
133+
dpctl_unsigned_long = auto()
134+
dpctl_long_long = auto()
135+
dpctl_unsigned_long_long = auto()
136+
dpctl_size_t = auto()
137+
dpctl_float = auto()
138+
dpctl_double = auto()
139+
dpctl_long_double = auto()
140+
dpctl_void_ptr = auto()
141+
dpctl_local_accessor = auto()

libsyclinterface/helper/include/dpctl_error_handlers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
///
2121
/// \file
2222
/// A functor to use for passing an error handler callback function to sycl
23-
/// context and queue contructors.
23+
/// context and queue constructors.
2424
//===----------------------------------------------------------------------===//
2525

2626
#pragma once

libsyclinterface/include/dpctl_sycl_enum_types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ typedef enum
102102
DPCTL_FLOAT,
103103
DPCTL_DOUBLE,
104104
DPCTL_LONG_DOUBLE,
105-
DPCTL_VOID_PTR
105+
DPCTL_VOID_PTR,
106+
DPCTL_LOCAL_ACCESSOR
106107
} DPCTLKernelArgType;
107108

108109
/*!

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,53 @@
3232
#include "dpctl_sycl_device_manager.h"
3333
#include "dpctl_sycl_type_casters.hpp"
3434
#include <exception>
35+
#include <iostream>
3536
#include <stdexcept>
3637
#include <sycl/sycl.hpp> /* SYCL headers */
3738
#include <utility>
3839

3940
using namespace sycl;
4041

42+
#define SET_LOCAL_ACCESSOR_ARG(CGH, NDIM, ARGTY, R, IDX) \
43+
do { \
44+
switch ((ARGTY)) { \
45+
case DPCTL_LONG_LONG: \
46+
{ \
47+
auto la = local_accessor<long long, NDIM>(R, CGH); \
48+
CGH.set_arg(IDX, la); \
49+
return true; \
50+
} \
51+
case DPCTL_UNSIGNED_LONG_LONG: \
52+
{ \
53+
auto la = local_accessor<unsigned long long, NDIM>(R, CGH); \
54+
CGH.set_arg(IDX, la); \
55+
return true; \
56+
} \
57+
case DPCTL_SIZE_T: \
58+
{ \
59+
auto la = local_accessor<size_t, NDIM>(R, CGH); \
60+
CGH.set_arg(IDX, la); \
61+
return true; \
62+
} \
63+
case DPCTL_FLOAT: \
64+
{ \
65+
auto la = local_accessor<float, NDIM>(R, CGH); \
66+
CGH.set_arg(IDX, la); \
67+
return true; \
68+
} \
69+
case DPCTL_DOUBLE: \
70+
{ \
71+
auto la = local_accessor<double, NDIM>(R, CGH); \
72+
CGH.set_arg(IDX, la); \
73+
return true; \
74+
} \
75+
default: \
76+
error_handler("Kernel argument could not be created.", __FILE__, \
77+
__func__, __LINE__); \
78+
return false; \
79+
} \
80+
} while (0);
81+
4182
namespace
4283
{
4384
static_assert(__SYCL_COMPILER_VERSION >= __SYCL_COMPILER_VERSION_REQUIRED,
@@ -51,11 +92,48 @@ typedef struct complex
5192
uint64_t imag;
5293
} complexNumber;
5394

95+
typedef struct MDLocalAccessorTy
96+
{
97+
size_t ndim;
98+
DPCTLKernelArgType dpctl_type_id;
99+
size_t dim0;
100+
size_t dim1;
101+
size_t dim2;
102+
} MDLocalAccessor;
103+
104+
bool set_local_accessor_arg(handler &cgh,
105+
size_t idx,
106+
const MDLocalAccessor *mdstruct)
107+
{
108+
switch (mdstruct->ndim) {
109+
case 1:
110+
{
111+
auto r = range<1>(mdstruct->dim0);
112+
SET_LOCAL_ACCESSOR_ARG(cgh, 1, mdstruct->dpctl_type_id, r, idx)
113+
}
114+
case 2:
115+
{
116+
auto r = range<2>(mdstruct->dim0, mdstruct->dim1);
117+
SET_LOCAL_ACCESSOR_ARG(cgh, 2, mdstruct->dpctl_type_id, r, idx)
118+
}
119+
case 3:
120+
{
121+
auto r = range<3>(mdstruct->dim0, mdstruct->dim1, mdstruct->dim2);
122+
SET_LOCAL_ACCESSOR_ARG(cgh, 3, mdstruct->dpctl_type_id, r, idx)
123+
}
124+
default:
125+
return false;
126+
}
127+
}
54128
/*!
55129
* @brief Set the kernel arg object
56130
*
57-
* @param cgh My Param doc
58-
* @param Arg My Param doc
131+
* @param cgh SYCL command group handler using which a kernel is going to
132+
* be submitted.
133+
* @param idx The position of the argument in the list of arguments passed
134+
* to a kernel.
135+
* @param Arg A void* representing a kernel argument.
136+
* @param Argty A typeid specifying the C++ type of the Arg parameter.
59137
*/
60138
bool set_kernel_arg(handler &cgh,
61139
size_t idx,
@@ -113,6 +191,9 @@ bool set_kernel_arg(handler &cgh,
113191
case DPCTL_VOID_PTR:
114192
cgh.set_arg(idx, Arg);
115193
break;
194+
case DPCTL_LOCAL_ACCESSOR:
195+
arg_set = set_local_accessor_arg(cgh, idx, (MDLocalAccessor *)Arg);
196+
break;
116197
default:
117198
arg_set = false;
118199
error_handler("Kernel argument could not be created.", __FILE__,
@@ -363,9 +444,9 @@ DPCTLQueue_SubmitRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
363444
cgh.depends_on(*unwrap<event>(DepEvents[i]));
364445

365446
for (auto i = 0ul; i < NArgs; ++i) {
366-
// \todo add support for Sycl buffers
367-
if (!set_kernel_arg(cgh, i, Args[i], ArgTypes[i]))
368-
exit(1);
447+
if (!set_kernel_arg(cgh, i, Args[i], ArgTypes[i])) {
448+
return nullptr;
449+
}
369450
}
370451
switch (NDims) {
371452
case 1:
@@ -418,9 +499,9 @@ DPCTLQueue_SubmitNDRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
418499
}
419500

420501
for (auto i = 0ul; i < NArgs; ++i) {
421-
// \todo add support for Sycl buffers
422-
if (!set_kernel_arg(cgh, i, Args[i], ArgTypes[i]))
423-
exit(1);
502+
if (!set_kernel_arg(cgh, i, Args[i], ArgTypes[i])) {
503+
return nullptr;
504+
}
424505
}
425506
switch (NDims) {
426507
case 1:

0 commit comments

Comments
 (0)