Skip to content

Commit 83f86e5

Browse files
[NFCI][SYCL] Return cached device_impl properties by reference (#18477)
1 parent 092189d commit 83f86e5

File tree

3 files changed

+78
-35
lines changed

3 files changed

+78
-35
lines changed

sycl/include/sycl/detail/util.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ template <> struct ABINeutralT<std::vector<std::string>> {
9090
template <typename T> using ABINeutralT_t = typename ABINeutralT<T>::type;
9191

9292
template <typename ParamT> auto convert_to_abi_neutral(ParamT &&Info) {
93-
using ParamNoRef = std::remove_reference_t<ParamT>;
94-
if constexpr (std::is_same_v<ParamNoRef, std::string>) {
93+
using ParamDecayT = std::decay_t<ParamT>;
94+
if constexpr (std::is_same_v<ParamDecayT, std::string>) {
9595
return detail::string{Info};
96-
} else if constexpr (std::is_same_v<ParamNoRef, std::vector<std::string>>) {
96+
} else if constexpr (std::is_same_v<ParamDecayT, std::vector<std::string>>) {
9797
std::vector<detail::string> Res;
9898
Res.reserve(Info.size());
9999
for (std::string &Str : Info) {

sycl/source/detail/device_impl.hpp

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
471471

472472
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
473473
template <typename Param, bool InitializingCache = false>
474-
typename Param::return_type get_info() const {
474+
decltype(auto) get_info() const {
475475
#define CALL_GET_INFO get_info
476476
#else
477477
// We've been exporting
@@ -484,10 +484,27 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
484484
#define CALL_GET_INFO get_info_abi_workaround
485485
template <typename Param> typename Param::return_type get_info() const;
486486
template <typename Param, bool InitializingCache = false>
487-
typename Param::return_type get_info_abi_workaround() const {
487+
decltype(auto) get_info_abi_workaround() const {
488488
#endif
489489
using execution_scope = ext::oneapi::experimental::execution_scope;
490490

491+
// With the return type of this function being automatically
492+
// deduced we can't simply do
493+
//
494+
// CASE(Desc1) { return get_info<Desc2>(); }
495+
//
496+
// because the function isn't defined yet and we can't auto-deduce the
497+
// return type for `Desc2` yet. The solution here is to make that delegation
498+
// template-parameter-dependent. We use the `InitializingCache` parameter
499+
// for that out of convenience.
500+
//
501+
// Note that for "eager" cache it's the programmer's responsibility that
502+
// the descriptor we delegate to is initialized first (by referencing that
503+
// descriptor first when defining the cache data member). For "CallOnce"
504+
// cache we want to be querying cached value so "false" is the right
505+
// template parameter for such delegation.
506+
constexpr bool DependentFalse = InitializingCache && false;
507+
491508
if constexpr (decltype(MCache)::has<Param>() && !InitializingCache) {
492509
return MCache.get<Param>();
493510
}
@@ -523,11 +540,13 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
523540
return range<3>{result[2], result[1], result[0]};
524541
}
525542
CASE(info::device::max_work_item_sizes<2>) {
526-
range<3> r3 = CALL_GET_INFO<info::device::max_work_item_sizes<3>>();
543+
range<3> r3 =
544+
CALL_GET_INFO<info::device::max_work_item_sizes<3>, DependentFalse>();
527545
return range<2>{r3[1], r3[2]};
528546
}
529547
CASE(info::device::max_work_item_sizes<1>) {
530-
range<3> r3 = CALL_GET_INFO<info::device::max_work_item_sizes<3>>();
548+
range<3> r3 =
549+
CALL_GET_INFO<info::device::max_work_item_sizes<3>, DependentFalse>();
531550
return range<1>{r3[2]};
532551
}
533552

@@ -608,16 +627,18 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
608627
// profiling, urDeviceGetGlobalTimestamps is not supported,
609628
// command_submit, command_start, command_end will be calculated. See
610629
// MFallbackProfiling
611-
return get_info_impl<UR_DEVICE_INFO_QUEUE_PROPERTIES>() &
612-
UR_QUEUE_FLAG_PROFILING_ENABLE;
630+
return static_cast<bool>(
631+
get_info_impl<UR_DEVICE_INFO_QUEUE_PROPERTIES>() &
632+
UR_QUEUE_FLAG_PROFILING_ENABLE);
613633
}
614634

615635
CASE(info::device::built_in_kernels) {
616636
return split_string(get_info_impl<UR_DEVICE_INFO_BUILT_IN_KERNELS>(),
617637
';');
618638
}
619639
CASE(info::device::built_in_kernel_ids) {
620-
auto names = CALL_GET_INFO<info::device::built_in_kernels>();
640+
auto names =
641+
CALL_GET_INFO<info::device::built_in_kernels, DependentFalse>();
621642

622643
std::vector<kernel_id> ids;
623644
ids.reserve(names.size());
@@ -655,7 +676,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
655676
"the info::device::preferred_interop_user_sync info descriptor can "
656677
"only be queried with an OpenCL backend");
657678

658-
return get_info_impl<UR_DEVICE_INFO_PREFERRED_INTEROP_USER_SYNC>();
679+
return static_cast<bool>(
680+
get_info_impl<UR_DEVICE_INFO_PREFERRED_INTEROP_USER_SYNC>());
659681
}
660682

661683
CASE(info::device::partition_properties) {
@@ -753,16 +775,19 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
753775
}
754776

755777
CASE(info::device::usm_device_allocations) {
756-
return get_info_impl<UR_DEVICE_INFO_USM_DEVICE_SUPPORT>() &
757-
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
778+
return static_cast<bool>(
779+
get_info_impl<UR_DEVICE_INFO_USM_DEVICE_SUPPORT>() &
780+
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS);
758781
}
759782
CASE(info::device::usm_host_allocations) {
760-
return get_info_impl<UR_DEVICE_INFO_USM_HOST_SUPPORT>() &
761-
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
783+
return static_cast<bool>(
784+
get_info_impl<UR_DEVICE_INFO_USM_HOST_SUPPORT>() &
785+
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS);
762786
}
763787
CASE(info::device::usm_shared_allocations) {
764-
return get_info_impl<UR_DEVICE_INFO_USM_SINGLE_SHARED_SUPPORT>() &
765-
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
788+
return static_cast<bool>(
789+
get_info_impl<UR_DEVICE_INFO_USM_SINGLE_SHARED_SUPPORT>() &
790+
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS);
766791
}
767792
CASE(info::device::usm_restricted_shared_allocations) {
768793
ur_device_usm_access_capability_flags_t cap_flags =
@@ -773,14 +798,16 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
773798
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_CONCURRENT_ACCESS));
774799
}
775800
CASE(info::device::usm_system_allocations) {
776-
return get_info_impl<UR_DEVICE_INFO_USM_SYSTEM_SHARED_SUPPORT>() &
777-
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
801+
return static_cast<bool>(
802+
get_info_impl<UR_DEVICE_INFO_USM_SYSTEM_SHARED_SUPPORT>() &
803+
UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS);
778804
}
779805

780806
CASE(info::device::opencl_c_version) {
781807
throw sycl::exception(errc::feature_not_supported,
782808
"Deprecated interface that hasn't been working for "
783809
"some time already");
810+
return std::string{}; // for return type deduction.
784811
}
785812

786813
CASE(ext::intel::info::device::max_mem_bandwidth) {
@@ -794,30 +821,34 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
794821
CASE(info::device::ext_oneapi_max_global_work_groups) {
795822
// Deprecated alias.
796823
return CALL_GET_INFO<
797-
ext::oneapi::experimental::info::device::max_global_work_groups>();
824+
ext::oneapi::experimental::info::device::max_global_work_groups,
825+
DependentFalse>();
798826
}
799827
CASE(info::device::ext_oneapi_max_work_groups_1d) {
800828
// Deprecated alias.
801829
return CALL_GET_INFO<
802-
ext::oneapi::experimental::info::device::max_work_groups<1>>();
830+
ext::oneapi::experimental::info::device::max_work_groups<1>,
831+
DependentFalse>();
803832
}
804833
CASE(info::device::ext_oneapi_max_work_groups_2d) {
805834
// Deprecated alias.
806835
return CALL_GET_INFO<
807-
ext::oneapi::experimental::info::device::max_work_groups<2>>();
836+
ext::oneapi::experimental::info::device::max_work_groups<2>,
837+
DependentFalse>();
808838
}
809839
CASE(info::device::ext_oneapi_max_work_groups_3d) {
810840
// Deprecated alias.
811841
return CALL_GET_INFO<
812-
ext::oneapi::experimental::info::device::max_work_groups<3>>();
842+
ext::oneapi::experimental::info::device::max_work_groups<3>,
843+
DependentFalse>();
813844
}
814845

815846
CASE(info::device::ext_oneapi_cuda_cluster_group) {
816847
if (getBackend() != backend::ext_oneapi_cuda)
817848
return false;
818849

819850
return get_info_impl_nocheck<UR_DEVICE_INFO_CLUSTER_LAUNCH_SUPPORT_EXP>()
820-
.value_or(0);
851+
.value_or(0) != 0;
821852
}
822853

823854
// ext_codeplay_device_traits.def
@@ -834,7 +865,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
834865
}
835866
CASE(ext::oneapi::experimental::info::device::max_work_groups<3>) {
836867
size_t Limit = CALL_GET_INFO<
837-
ext::oneapi::experimental::info::device::max_global_work_groups>();
868+
ext::oneapi::experimental::info::device::max_global_work_groups,
869+
DependentFalse>();
838870

839871
// TODO: std::array<size_t, 3> ?
840872
size_t result[3];
@@ -846,12 +878,14 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
846878
}
847879
CASE(ext::oneapi::experimental::info::device::max_work_groups<2>) {
848880
id<3> max_3d = CALL_GET_INFO<
849-
ext::oneapi::experimental::info::device::max_work_groups<3>>();
881+
ext::oneapi::experimental::info::device::max_work_groups<3>,
882+
DependentFalse>();
850883
return id<2>{max_3d[1], max_3d[2]};
851884
}
852885
CASE(ext::oneapi::experimental::info::device::max_work_groups<1>) {
853886
id<3> max_3d = CALL_GET_INFO<
854-
ext::oneapi::experimental::info::device::max_work_groups<3>>();
887+
ext::oneapi::experimental::info::device::max_work_groups<3>,
888+
DependentFalse>();
855889
return id<1>{max_3d[2]};
856890
}
857891

@@ -896,8 +930,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
896930
}
897931

898932
CASE(ext::oneapi::experimental::info::device::mipmap_max_anisotropy) {
899-
// Implicit conversion:
900-
return get_info_impl<UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP>();
933+
return static_cast<float>(
934+
get_info_impl<UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP>());
901935
}
902936

903937
CASE(ext::oneapi::experimental::info::device::component_devices) {
@@ -906,7 +940,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
906940
if (!Devs.has_val()) {
907941
ur_result_t Err = Devs.error();
908942
if (Err == UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION)
909-
return {};
943+
return std::vector<sycl::device>{};
910944
getAdapter()->checkUrResult(Err);
911945
}
912946

@@ -935,8 +969,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
935969
"must have a composite device.");
936970
}
937971
CASE(ext::oneapi::info::device::num_compute_units) {
938-
// uint32_t -> size_t
939-
return get_info_impl<UR_DEVICE_INFO_NUM_COMPUTE_UNITS>();
972+
return static_cast<size_t>(
973+
get_info_impl<UR_DEVICE_INFO_NUM_COMPUTE_UNITS>());
940974
}
941975

942976
// ext_intel_device_traits.def
@@ -1028,8 +1062,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
10281062
return get_info_impl<UR_DEVICE_INFO_MEMORY_BUS_WIDTH>();
10291063
}
10301064
CASE(ext::intel::info::device::max_compute_queue_indices) {
1031-
// uint32_t->int implicit conversion.
1032-
return get_info_impl<UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES>();
1065+
return static_cast<int>(
1066+
get_info_impl<UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES>());
10331067
}
10341068
CASE(ext::intel::esimd::info::device::has_2d_block_io_support) {
10351069
if (!has(aspect::ext_intel_esimd))
@@ -1093,7 +1127,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
10931127
}
10941128
else {
10951129
constexpr auto Desc = UrInfoCode<Param>::value;
1096-
return get_info_impl<Desc>();
1130+
return static_cast<typename Param::return_type>(get_info_impl<Desc>());
10971131
}
10981132
#undef CASE
10991133
}

sycl/source/device.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ bool device::has_extension(detail::string_view ext_name) const {
124124
template <typename Param>
125125
detail::ABINeutralT_t<typename detail::is_device_info_desc<Param>::return_type>
126126
device::get_info_impl() const {
127+
static_assert(
128+
std::is_same_v<typename detail::is_device_info_desc<Param>::return_type,
129+
decltype(impl->template
130+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
131+
get_info
132+
#else
133+
get_info_abi_workaround
134+
#endif
135+
<Param, true /* InitializingCache */>())>);
127136
return detail::convert_to_abi_neutral(impl->template get_info<Param>());
128137
}
129138

0 commit comments

Comments
 (0)