Skip to content

[SYCL] get kernel info with free functions #18866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
2 changes: 2 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,8 @@ bool SemaSYCL::isFreeFunction(const FunctionDecl *FD) {
NameValuePair.first == "sycl-single-task-kernel";
});
IsFreeFunctionAttr = it != NameValuePairs.end();
if (IsFreeFunctionAttr)
break;
}
if (Redecl->isFirstDecl()) {
if (IsFreeFunctionAttr)
Expand Down
57 changes: 57 additions & 0 deletions clang/test/CodeGenSYCL/free_function_int_header.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ __attribute__((sycl_device))
void ff_20(sycl::accessor<int, 1, sycl::access::mode::read_write> acc) {
}

// Free function whose "sycl-nd-range-kernel" IR attribute is listed AFTER an
// unrelated "work_group_size" attribute. The attribute order is the point of
// this case, so do not reorder the two attribute lines.
[[__sycl_detail__::add_ir_attributes_function("work_group_size", 16)]]
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]]
void ff_21(AliasType start, AliasType *ptr) {
}

// Free function whose "sycl-nd-range-kernel" IR attribute is listed BEFORE an
// unrelated "work_group_size" attribute (the mirror image of the ordering used
// by ff_21). The attribute order is the point of this case, so do not reorder
// the two attribute lines.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]]
[[__sycl_detail__::add_ir_attributes_function("work_group_size", 16)]]
void ff_22(AliasType start, AliasType *ptr) {
}

// CHECK: const char* const kernel_names[] = {
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piiii
Expand Down Expand Up @@ -286,6 +296,8 @@ void ff_20(sycl::accessor<int, 1, sycl::access::mode::read_write> acc) {
// CHECK-NEXT: {{.*}}__sycl_kernel_free_functions5tests5ff_18ENS_3AggEPS1_
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_19N14free_functions16KArgWithPtrArrayILi50EEE
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_20N4sycl3_V18accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEEE
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_217DerivedPS_
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_227DerivedPS_

// CHECK-NEXT: ""
// CHECK-NEXT: };
Expand Down Expand Up @@ -980,6 +992,37 @@ void ff_20(sycl::accessor<int, 1, sycl::access::mode::read_write> acc) {
// CHECK-NEXT: };
// CHECK-NEXT: }


// CHECK: void ff_21(Derived start, Derived * ptr);
// CHECK-NEXT: static constexpr auto __sycl_shim30() {
// CHECK-NEXT: return (void (*)(struct Derived, struct Derived *))ff_21;
// CHECK-NEXT: }
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim30()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim30()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: }

// CHECK: void ff_22(Derived start, Derived * ptr);
// CHECK-NEXT: static constexpr auto __sycl_shim31() {
// CHECK-NEXT: return (void (*)(struct Derived, struct Derived *))ff_22;
// CHECK-NEXT: }
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim31()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim31()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: }

// CHECK: #include <sycl/kernel_bundle.hpp>

// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_2Piii
Expand Down Expand Up @@ -1196,3 +1239,17 @@ void ff_20(sycl::accessor<int, 1, sycl::access::mode::read_write> acc) {
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z19__sycl_kernel_ff_20N4sycl3_V18accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEEE"});
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim30()>() {
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z19__sycl_kernel_ff_217DerivedPS_"});
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim31()>() {
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z19__sycl_kernel_ff_227DerivedPS_"});
// CHECK-NEXT: }
// CHECK-NEXT: }
33 changes: 33 additions & 0 deletions sycl/include/sycl/ext/oneapi/get_kernel_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <sycl/detail/export.hpp>
#include <sycl/detail/info_desc_helpers.hpp>
#include <sycl/device.hpp>
#include <sycl/kernel_bundle.hpp>
#include <sycl/kernel_bundle_enums.hpp>
#include <sycl/queue.hpp>

Expand Down Expand Up @@ -53,6 +54,38 @@ get_kernel_info(const queue &Q) {
Q.get_device());
}

// For free functions.
namespace experimental {

/// Queries a non-device-specific kernel information descriptor for the
/// free-function kernel \p Func.
///
/// An executable kernel bundle containing \p Func is obtained for \p ctxt, the
/// kernel object is extracted from it and the query is forwarded to
/// kernel::get_info.
///
/// \tparam Func  pointer to a free-function kernel (is_kernel_v<Func> must
///               hold, otherwise this overload is removed via SFINAE).
/// \tparam Param kernel information descriptor to query.
/// \param ctxt   context the kernel bundle is built for.
/// \return the value of the descriptor \p Param for the kernel.
template <auto *Func, typename Param>
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
                 typename sycl::detail::is_kernel_info_desc<Param>::return_type>
get_kernel_info(const context &ctxt) {
  auto Kernel = sycl::ext::oneapi::experimental::get_kernel_bundle<
                    Func, sycl::bundle_state::executable>(ctxt)
                    .template ext_oneapi_get_kernel<Func>();
  return Kernel.template get_info<Param>();
}

/// Queries a device-specific kernel information descriptor for the
/// free-function kernel \p Func on device \p dev.
///
/// An executable kernel bundle containing \p Func is obtained for \p ctxt, the
/// kernel object is extracted from it and the query is forwarded to the
/// device-taking overload of kernel::get_info.
///
/// \tparam Func  pointer to a free-function kernel (is_kernel_v<Func> must
///               hold, otherwise this overload is removed via SFINAE).
/// \tparam Param device-specific kernel information descriptor to query.
/// \param ctxt   context the kernel bundle is built for.
/// \param dev    device the descriptor is queried for.
/// \return the value of the descriptor \p Param for the kernel on \p dev.
template <auto *Func, typename Param>
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
                 typename sycl::detail::is_kernel_device_specific_info_desc<
                     Param>::return_type>
get_kernel_info(const context &ctxt, const device &dev) {
  auto Kernel = sycl::ext::oneapi::experimental::get_kernel_bundle<
                    Func, sycl::bundle_state::executable>(ctxt)
                    .template ext_oneapi_get_kernel<Func>();
  return Kernel.template get_info<Param>(dev);
}

template <auto *Func, typename Param>
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
typename sycl::detail::is_kernel_device_specific_info_desc<
Param>::return_type>
get_kernel_info(const queue &q) {
return get_kernel_info<Func, Param>(q.get_context(), q.get_device());
}
} // namespace experimental
} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl
132 changes: 132 additions & 0 deletions sycl/test-e2e/FreeFunctionKernels/get_kernel_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// REQUIRES: aspect-usm_shared_allocations
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// XFAIL: cpu
// XFAIL-TRACKER: CMPLRLLVM-68536
// UNSUPPORTED: cuda, hip
// UNSUPPORTED-INTENDED: Not implemented yet for Nvidia/AMD backends.

#include <iostream>
#include <sycl/ext/oneapi/free_function_queries.hpp>
#include <sycl/ext/oneapi/get_kernel_info.hpp>
#include <sycl/kernel_bundle.hpp>
#include <sycl/usm.hpp>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

static constexpr size_t NUM = 1024;
static constexpr size_t WGSIZE = 16;

// Free-function kernel under test: a 1-D nd-range kernel annotated with a
// work-group size of WGSIZE. Each work-item stores `start` plus its global
// linear id into the slot of `ptr` with that same index.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::work_group_size<WGSIZE>))
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void func(float start, float *ptr) {
  const auto item = syclext::this_work_item::get_nd_item<1>();
  const size_t gid = item.get_global_linear_id();
  ptr[gid] = start + static_cast<float>(gid);
}

// Value passed as the `start` argument of `func` by call_kernel_code and
// assumed by check_result when computing the expected output.
static constexpr float START_VALUE = 3.0f;

// Validates the buffer produced by `func` when launched over NUM work-items
// with start == START_VALUE: element i must hold START_VALUE + i.
//
// The expected value is computed with the same float expression the kernel
// uses and every value involved is a small integer that float represents
// exactly, so an exact comparison is appropriate here.
//
// Returns true when all NUM elements match, false on the first mismatch.
bool check_result(const float *ptr) {
  for (size_t i = 0; i < NUM; ++i) {
    const float expected = START_VALUE + static_cast<float>(i);
    if (ptr[i] != expected)
      return false;
  }
  return true;
}

// Launches `kernel` (expected to be the kernel object for `func`) on `q` over
// a 1-D nd-range of NUM work-items in groups of WGSIZE, then validates the
// output buffer.
//
// The argument types handed to set_args must match the kernel signature
// `void func(float, float *)`: a float scalar and a float USM pointer.
// Passing an int scalar / int buffer would make the kernel reinterpret the
// bits and the validation would compare unrelated representations.
//
// Returns true when the kernel produced the expected values, false when the
// allocation failed or the results are wrong.
static bool call_kernel_code(sycl::queue &q, sycl::kernel &kernel) {
  float *ptr = sycl::malloc_shared<float>(NUM, q);
  if (ptr == nullptr)
    return false; // USM allocation failed; nothing to run or to free.

  q.submit([&](sycl::handler &cgh) {
     cgh.set_args(START_VALUE, ptr);
     sycl::nd_range ndr{{NUM}, {WGSIZE}};
     cgh.parallel_for(ndr, kernel);
   }).wait();

  const bool ok = check_result(ptr);
  sycl::free(ptr, q);
  return ok;
}

// Checks the (context, device) overload of syclexp::get_kernel_info for
// `func`: the device-specific work_group_size it reports must equal the value
// reported by kernel::get_info on `k` for the queue's device.
// Prints a diagnostic to stderr and returns false on mismatch; returns true
// otherwise.
bool test_ctxt_dev(sycl::kernel &k, sycl::queue &q) {
  using wg_size_desc = sycl::info::kernel_device_specific::work_group_size;

  const auto wg_size_cmp = k.get_info<wg_size_desc>(q.get_device());
  const auto wg_size = syclexp::get_kernel_info<func, wg_size_desc>(
      q.get_context(), q.get_device());

  const bool sizes_match = (wg_size_cmp == wg_size);
  if (!sizes_match) {
    std::cerr << "Work group size from get_info: " << wg_size_cmp
              << " is not equal to work group size from get_kernel_info: "
              << wg_size << std::endl;
  }
  return sizes_match;
}

// Checks the context-only overload of syclexp::get_kernel_info for `func` by
// querying the kernel's attribute string and validating it:
//   1. the string must contain "work_group_size(";
//   2. a ',' must follow that marker;
//   3. the number between the marker and the ',' must equal WGSIZE;
//   4. the device-specific work_group_size reported by `k` for the queue's
//      device must not be smaller than that number.
// Each failed step prints a diagnostic to stderr and returns false; returns
// true when all steps pass. std::stoul is used for parsing, so a non-numeric
// value propagates its exception to the caller.
bool test_ctxt(sycl::kernel &k, sycl::queue &q) {
  const auto attributes =
      syclexp::get_kernel_info<func, sycl::info::kernel::attributes>(
          q.get_context());
  const std::string wg_size_str = "work_group_size(";

  // A single search covers both the "empty string" and the "marker missing"
  // cases: searching an empty string for a non-empty marker yields npos.
  const auto marker_pos = attributes.find(wg_size_str);
  if (marker_pos == std::string::npos) {
    std::cerr << "Work group size attribute is not found in kernel attributes, "
                 "attributes:"
              << attributes << std::endl;
    return false;
  }

  const auto value_begin = marker_pos + wg_size_str.size();
  const auto value_end = attributes.find(',', value_begin);
  if (value_end == std::string::npos) {
    std::cerr << "Comma not found in work group size attribute string"
              << std::endl;
    return false;
  }

  const size_t wg_size =
      std::stoul(attributes.substr(value_begin, value_end - value_begin));
  if (wg_size != WGSIZE) {
    std::cerr << "Work group size from attributes: " << wg_size
              << " is not equal to expected work group size: " << WGSIZE
              << std::endl;
    return false;
  }

  const auto wg_size_cmp =
      k.get_info<sycl::info::kernel_device_specific::work_group_size>(
          q.get_device());
  if (wg_size_cmp < wg_size) {
    std::cerr << "Work group size from get_info: " << wg_size_cmp
              << " is less work group size from attributes: " << wg_size
              << std::endl;
    return false;
  }
  return true;
}

// Checks the queue overload of syclexp::get_kernel_info for `func`: the
// device-specific work_group_size it reports must equal the value reported by
// kernel::get_info on `k` for the queue's device.
// Prints a diagnostic to stderr and returns false on mismatch; returns true
// otherwise.
bool test_queue(sycl::kernel &k, sycl::queue &q) {
  using wg_size_desc = sycl::info::kernel_device_specific::work_group_size;

  const auto wg_size_cmp = k.get_info<wg_size_desc>(q.get_device());
  const auto wg_size = syclexp::get_kernel_info<func, wg_size_desc>(q);

  const bool sizes_match = (wg_size_cmp == wg_size);
  if (!sizes_match) {
    std::cerr << "Work group size from get_info: " << wg_size_cmp
              << " is not equal to work group size from get_kernel_info: "
              << wg_size << std::endl;
  }
  return sizes_match;
}

// Test driver: builds an executable kernel bundle for `func`, runs the kernel
// once, then exercises all three free-function get_kernel_info overloads.
// Exits with 0 when every overload check passes and 1 otherwise.
int main() {
  sycl::queue q;
  sycl::context ctxt = q.get_context();

  auto exe_bndl =
      syclexp::get_kernel_bundle<func, sycl::bundle_state::executable>(ctxt);
  sycl::kernel k_func = exe_bndl.template ext_oneapi_get_kernel<func>();

  // NOTE(review): the value returned by call_kernel_code is discarded, so the
  // kernel's output does not influence the exit code.
  call_kernel_code(q, k_func);

  // Run every check unconditionally so each one can print its diagnostic.
  const bool ctxt_dev_ok = test_ctxt_dev(k_func, q);
  const bool ctxt_ok = test_ctxt(k_func, q);
  const bool queue_ok = test_queue(k_func, q);

  const bool all_ok = ctxt_dev_ok && ctxt_ok && queue_ok;
  return all_ok ? 0 : 1;
}