File tree Expand file tree Collapse file tree 3 files changed +31
-0
lines changed Expand file tree Collapse file tree 3 files changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -204,6 +204,8 @@ cdef extern from "syclinterface/dpctl_sycl_device_interface.h":
204
204
cdef uint64_t DPCTLDevice_GetGlobalMemCacheSize(const DPCTLSyclDeviceRef DRef)
205
205
cdef _global_mem_cache_type DPCTLDevice_GetGlobalMemCacheType(
206
206
const DPCTLSyclDeviceRef DRef)
207
+ cdef size_t * DPCTLDevice_GetSubGroupSizes(const DPCTLSyclDeviceRef DRef,
208
+ size_t * res_len)
207
209
208
210
209
211
cdef extern from " syclinterface/dpctl_sycl_device_manager.h" :
Original file line number Diff line number Diff line change @@ -65,6 +65,7 @@ from ._backend cimport ( # noqa: E211
65
65
DPCTLDevice_GetPreferredVectorWidthShort,
66
66
DPCTLDevice_GetProfilingTimerResolution,
67
67
DPCTLDevice_GetSubGroupIndependentForwardProgress,
68
+ DPCTLDevice_GetSubGroupSizes,
68
69
DPCTLDevice_GetVendor,
69
70
DPCTLDevice_HasAspect,
70
71
DPCTLDevice_Hash,
@@ -884,6 +885,28 @@ cdef class SyclDevice(_SyclDevice):
884
885
self ._device_ref
885
886
)
886
887
888
+ @property
889
+ def sub_group_sizes (self ):
890
+ """ Returns list of supported sub-group sizes for this device.
891
+
892
+ Returns:
893
+ List[int]: List of supported sub-group sizes.
894
+ """
895
+ cdef size_t * sg_sizes = NULL
896
+ cdef size_t sg_sizes_len = 0
897
+ cdef size_t i
898
+
899
+ sg_sizes = DPCTLDevice_GetSubGroupSizes(
900
+ self ._device_ref, & sg_sizes_len)
901
+ if (sg_sizes is not NULL and sg_sizes_len > 0 ):
902
+ res = list ()
903
+ for i in range (sg_sizes_len):
904
+ res.append(sg_sizes[i])
905
+ DPCTLSize_t_Array_Delete(sg_sizes)
906
+ return res
907
+ else :
908
+ return []
909
+
887
910
@property
888
911
def sycl_platform (self ):
889
912
""" Returns the platform associated with this device.
Original file line number Diff line number Diff line change @@ -115,6 +115,11 @@ def check_max_num_sub_groups(device):
115
115
assert max_num_sub_groups > 0
116
116
117
117
118
+ def check_sub_group_sizes (device ):
119
+ sg_sizes = device .sub_group_sizes
120
+ assert all (el > 0 for el in sg_sizes )
121
+
122
+
118
123
def check_has_aspect_host (device ):
119
124
try :
120
125
device .has_aspect_host
@@ -605,6 +610,7 @@ def check_global_mem_cache_line_size(device):
605
610
check_max_work_item_sizes ,
606
611
check_max_work_group_size ,
607
612
check_max_num_sub_groups ,
613
+ check_sub_group_sizes ,
608
614
check_is_accelerator ,
609
615
check_is_cpu ,
610
616
check_is_gpu ,
You can’t perform that action at this time.
0 commit comments