Skip to content

Allow setting local_accessors as kernel arguments in DPCTLQueue_Submit #1558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions dpctl/enum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@
"""
from enum import Enum, auto

__all__ = [
"device_type",
"backend_type",
"event_status_type",
]
__all__ = ["device_type", "backend_type", "event_status_type"]


class device_type(Enum):
Expand Down
2 changes: 1 addition & 1 deletion libsyclinterface/helper/include/dpctl_error_handlers.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
///
/// \file
/// A functor to use for passing an error handler callback function to sycl
/// context and queue contructors.
/// context and queue constructors.
//===----------------------------------------------------------------------===//

#pragma once
Expand Down
1 change: 1 addition & 0 deletions libsyclinterface/include/dpctl_sycl_enum_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ typedef enum
DPCTL_FLOAT32_T,
DPCTL_FLOAT64_T,
DPCTL_VOID_PTR,
DPCTL_LOCAL_ACCESSOR,
DPCTL_UNSUPPORTED_KERNEL_ARG
} DPCTLKernelArgType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ _GetKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
else {
error_handler("Kernel named " + std::string(kernel_name) +
" could not be found.",
__FILE__, __func__, __LINE__);
__FILE__, __func__, __LINE__, error_level::error);
return nullptr;
}
}
Expand All @@ -541,7 +541,7 @@ bool _HasKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
auto zeKernelCreateFn = get_zeKernelCreate();
if (zeKernelCreateFn == nullptr) {
error_handler("Could not load zeKernelCreate function.", __FILE__,
__func__, __LINE__);
__func__, __LINE__, error_level::error);
return false;
}

Expand All @@ -564,7 +564,7 @@ bool _HasKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
if (ze_status != ZE_RESULT_ERROR_INVALID_KERNEL_NAME) {
error_handler("zeKernelCreate failed: " +
_GetErrorCode_ze_impl(ze_status),
__FILE__, __func__, __LINE__);
__FILE__, __func__, __LINE__, error_level::error);
return false;
}
}
Expand Down
116 changes: 112 additions & 4 deletions libsyclinterface/source/dpctl_sycl_queue_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,76 @@

using namespace sycl;

#define SET_LOCAL_ACCESSOR_ARG(CGH, NDIM, ARGTY, R, IDX) \
do { \
switch ((ARGTY)) { \
case DPCTL_INT8_T: \
{ \
auto la = local_accessor<int8_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT8_T: \
{ \
auto la = local_accessor<uint8_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT16_T: \
{ \
auto la = local_accessor<int16_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT16_T: \
{ \
auto la = local_accessor<uint16_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT32_T: \
{ \
auto la = local_accessor<int32_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT32_T: \
{ \
auto la = local_accessor<uint32_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT64_T: \
{ \
auto la = local_accessor<int64_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT64_T: \
{ \
auto la = local_accessor<uint64_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_FLOAT32_T: \
{ \
auto la = local_accessor<float, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_FLOAT64_T: \
{ \
auto la = local_accessor<double, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
default: \
error_handler("Kernel argument could not be created.", __FILE__, \
__func__, __LINE__, error_level::error); \
return false; \
} \
} while (0);

namespace
{
static_assert(__SYCL_COMPILER_VERSION >= __SYCL_COMPILER_VERSION_REQUIRED,
Expand All @@ -51,6 +121,15 @@ typedef struct complex
uint64_t imag;
} complexNumber;

typedef struct MDLocalAccessorTy
{
size_t ndim;
DPCTLKernelArgType dpctl_type_id;
size_t dim0;
size_t dim1;
size_t dim2;
} MDLocalAccessor;

void set_dependent_events(handler &cgh,
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
size_t NDepEvents)
Expand All @@ -62,11 +141,39 @@ void set_dependent_events(handler &cgh,
}
}

bool set_local_accessor_arg(handler &cgh,
size_t idx,
const MDLocalAccessor *mdstruct)
{
switch (mdstruct->ndim) {
case 1:
{
auto r = range<1>(mdstruct->dim0);
SET_LOCAL_ACCESSOR_ARG(cgh, 1, mdstruct->dpctl_type_id, r, idx);
}
case 2:
{
auto r = range<2>(mdstruct->dim0, mdstruct->dim1);
SET_LOCAL_ACCESSOR_ARG(cgh, 2, mdstruct->dpctl_type_id, r, idx);
}
case 3:
{
auto r = range<3>(mdstruct->dim0, mdstruct->dim1, mdstruct->dim2);
SET_LOCAL_ACCESSOR_ARG(cgh, 3, mdstruct->dpctl_type_id, r, idx);
}
default:
return false;
}
}
/*!
* @brief Set the kernel arg object
*
* @param cgh My Param doc
* @param Arg My Param doc
* @param cgh SYCL command group handler using which a kernel is going to
* be submitted.
* @param idx The position of the argument in the list of arguments passed
* to a kernel.
* @param Arg A void* representing a kernel argument.
* @param Argty A typeid specifying the C++ type of the Arg parameter.
*/
bool set_kernel_arg(handler &cgh,
size_t idx,
Expand Down Expand Up @@ -109,10 +216,11 @@ bool set_kernel_arg(handler &cgh,
case DPCTL_VOID_PTR:
cgh.set_arg(idx, Arg);
break;
case DPCTL_LOCAL_ACCESSOR:
arg_set = set_local_accessor_arg(cgh, idx, (MDLocalAccessor *)Arg);
break;
default:
arg_set = false;
error_handler("Kernel argument could not be created.", __FILE__,
__func__, __LINE__);
break;
}
return arg_set;
Expand Down
3 changes: 3 additions & 0 deletions libsyclinterface/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ set(spirv-test-files
multi_kernel.spv
oneD_range_kernel_inttys_fp32.spv
oneD_range_kernel_fp64.spv
local_accessor_kernel_inttys_fp32.spv
local_accessor_kernel_fp64.spv
)

foreach(tf ${spirv-test-files})
Expand Down Expand Up @@ -55,6 +57,7 @@ add_sycl_to_target(
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_platform_invalid_filters.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_queue_manager.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_queue_submit.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_queue_submit_local_accessor_arg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_queue_interface.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_usm_interface.cpp
)
Expand Down
Binary file not shown.
Binary file not shown.
Loading