Skip to content

Commit b88bd7a

Browse files
Added size_t *DPCTLDevice_GetSubGroupSizes(DRef, size_t res_len)
The function exposes `device::get_info<info::device::sub_group_sizes>()` which returns `std::vector<size_t>`. DPCTLDevice_GetSubGroupSizes returns pointer to allocated array, populated with the content of the result std::vector. res_len is set with the size of the result std::vector.
1 parent 530a7d7 commit b88bd7a

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

libsyclinterface/include/dpctl_sycl_device_interface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,4 +651,17 @@ DPCTL_API
651651
DPCTLGlobalMemCacheType
652652
DPCTLDevice_GetGlobalMemCacheType(__dpctl_keep const DPCTLSyclDeviceRef DRef);
653653

654+
/*!
655+
* @brief Wrapper for get_info<info::device::sub_group_sizes>().
656+
*
657+
* @param DRef Opaque pointer to a ``sycl::device``
658+
* @param res_len Populated with size of the returned array
659+
* @return Returns the valid result if device exists else returns NULL.
660+
* @ingroup DeviceInterface
661+
*/
662+
DPCTL_API
663+
__dpctl_keep size_t *
664+
DPCTLDevice_GetSubGroupSizes(__dpctl_keep const DPCTLSyclDeviceRef DRef,
665+
size_t *res_len);
666+
654667
DPCTL_C_EXTERN_C_END

libsyclinterface/source/dpctl_sycl_device_interface.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,3 +743,30 @@ DPCTLDevice_GetGlobalMemCacheType(__dpctl_keep const DPCTLSyclDeviceRef DRef)
743743
return DPCTL_MEM_CACHE_TYPE_INDETERMINATE;
744744
}
745745
}
746+
747+
__dpctl_keep size_t *
748+
DPCTLDevice_GetSubGroupSizes(__dpctl_keep const DPCTLSyclDeviceRef DRef,
749+
size_t *res_len)
750+
{
751+
size_t *sizes = nullptr;
752+
std::vector<size_t> sg_sizes;
753+
*res_len = 0;
754+
auto D = unwrap<device>(DRef);
755+
if (D) {
756+
try {
757+
sg_sizes = D->get_info<info::device::sub_group_sizes>();
758+
*res_len = sg_sizes.size();
759+
} catch (std::exception const &e) {
760+
error_handler(e, __FILE__, __func__, __LINE__);
761+
}
762+
try {
763+
sizes = new size_t[sg_sizes.size()];
764+
} catch (std::exception const &e) {
765+
error_handler(e, __FILE__, __func__, __LINE__);
766+
}
767+
for (auto i = 0ul; (sizes != nullptr) && i < sg_sizes.size(); ++i) {
768+
sizes[i] = sg_sizes[i];
769+
}
770+
}
771+
return sizes;
772+
}

libsyclinterface/tests/test_sycl_device_interface.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,22 @@ TEST_P(TestDPCTLSyclDeviceInterface, ChkGetMaxNumSubGroups)
205205
EXPECT_TRUE(n > 0);
206206
}
207207

208+
TEST_P(TestDPCTLSyclDeviceInterface, ChkGetSubGroupSizes)
209+
{
210+
size_t sg_sizes_len = 0;
211+
size_t *sg_sizes = nullptr;
212+
EXPECT_NO_FATAL_FAILURE(
213+
sg_sizes = DPCTLDevice_GetSubGroupSizes(DRef, &sg_sizes_len));
214+
if (DPCTLDevice_IsAccelerator(DRef))
215+
EXPECT_TRUE(sg_sizes_len >= 0);
216+
else
217+
EXPECT_TRUE(sg_sizes_len > 0);
218+
for (size_t i = 0; i < sg_sizes_len; ++i) {
219+
EXPECT_TRUE(sg_sizes > 0);
220+
}
221+
EXPECT_NO_FATAL_FAILURE(DPCTLSize_t_Array_Delete(sg_sizes));
222+
}
223+
208224
TEST_P(TestDPCTLSyclDeviceInterface, ChkGetPlatform)
209225
{
210226
DPCTLSyclPlatformRef PRef = nullptr;
@@ -751,3 +767,13 @@ TEST_F(TestDPCTLSyclDeviceNullArgs, ChkGetGlobalMemCacheType)
751767
EXPECT_NO_FATAL_FAILURE(res = DPCTLDevice_GetGlobalMemCacheType(Null_DRef));
752768
ASSERT_TRUE(res == DPCTL_MEM_CACHE_TYPE_INDETERMINATE);
753769
}
770+
771+
TEST_F(TestDPCTLSyclDeviceNullArgs, ChkGetSubGroupSizes)
772+
{
773+
size_t *sg_sizes = nullptr;
774+
size_t sg_sizes_len = 0;
775+
EXPECT_NO_FATAL_FAILURE(
776+
sg_sizes = DPCTLDevice_GetSubGroupSizes(Null_DRef, &sg_sizes_len));
777+
ASSERT_TRUE(sg_sizes == nullptr);
778+
ASSERT_TRUE(sg_sizes_len == 0);
779+
}

0 commit comments

Comments
 (0)