Skip to content

Commit 72e80a4

Browse files
authored
Merge pull request #2316 from 0x12CC/coop_kernel_query
Change `urKernelSuggestMaxCooperativeGroupCountExp` to accept an ND size parameter
2 parents 6e5d0e6 + 9c7e56c commit 72e80a4

File tree

16 files changed: 105 additions and 41 deletions.

include/ur_api.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9486,13 +9486,17 @@ urEnqueueCooperativeKernelLaunchExp(
94869486
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
94879487
/// + `NULL == hKernel`
94889488
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
9489+
/// + `NULL == pLocalWorkSize`
94899490
/// + `NULL == pGroupCountRet`
94909491
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
94919492
UR_APIEXPORT ur_result_t UR_APICALL
94929493
urKernelSuggestMaxCooperativeGroupCountExp(
94939494
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
9494-
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
9495-
///< kernel is launched
9495+
uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
9496+
///< work-items
9497+
const size_t *pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
9498+
///< number of local work-items forming a work-group that will execute the
9499+
///< kernel function.
94969500
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
94979501
///< that will be used when the kernel is launched
94989502
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
@@ -11028,7 +11032,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
1102811032
/// allowing the callback the ability to modify the parameter's value
1102911033
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
1103011034
ur_kernel_handle_t *phKernel;
11031-
size_t *plocalWorkSize;
11035+
uint32_t *pworkDim;
11036+
const size_t **ppLocalWorkSize;
1103211037
size_t *pdynamicSharedMemorySize;
1103311038
uint32_t **ppGroupCountRet;
1103411039
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;

include/ur_ddi.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
651651
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
652652
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
653653
ur_kernel_handle_t,
654-
size_t,
654+
uint32_t,
655+
const size_t *,
655656
size_t,
656657
uint32_t *);
657658

include/ur_print.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13074,9 +13074,15 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1307413074
*(params->phKernel));
1307513075

1307613076
os << ", ";
13077-
os << ".localWorkSize = ";
13077+
os << ".workDim = ";
13078+
13079+
os << *(params->pworkDim);
13080+
13081+
os << ", ";
13082+
os << ".pLocalWorkSize = ";
1307813083

13079-
os << *(params->plocalWorkSize);
13084+
ur::details::printPtr(os,
13085+
*(params->ppLocalWorkSize));
1308013086

1308113087
os << ", ";
1308213088
os << ".dynamicSharedMemorySize = ";

scripts/core/exp-cooperative-kernels.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81-
- type: size_t
82-
name: localWorkSize
83-
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
81+
- type: uint32_t
82+
name: workDim
83+
desc: "[in] number of dimensions, from 1 to 3, to specify the work-group work-items"
84+
- type: "const size_t*"
85+
name: pLocalWorkSize
86+
desc: |
87+
[in] pointer to an array of workDim unsigned values that specify the number of local work-items forming a work-group that will execute the kernel function.
8488
- type: size_t
8589
name: dynamicSharedMemorySize
8690
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"

source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
190190
}
191191

192192
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
193-
ur_kernel_handle_t hKernel, size_t localWorkSize,
193+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
194194
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
195195
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
196196

197+
size_t localWorkSize = pLocalWorkSize[0];
198+
localWorkSize *= (workDim >= 2 ? pLocalWorkSize[1] : 1);
199+
localWorkSize *= (workDim == 3 ? pLocalWorkSize[2] : 1);
200+
197201
// We need to set the active current device for this kernel explicitly here,
198202
// because the occupancy querying API does not take device parameter.
199203
ur_device_handle_t Device = hKernel->getProgram()->getDevice();

source/adapters/hip/kernel.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
169169
}
170170

171171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
172-
ur_kernel_handle_t hKernel, size_t localWorkSize,
172+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
173173
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
174174
std::ignore = hKernel;
175-
std::ignore = localWorkSize;
175+
std::ignore = workDim;
176+
std::ignore = pLocalWorkSize;
176177
std::ignore = dynamicSharedMemorySize;
177178
std::ignore = pGroupCountRet;
178179
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/level_zero/kernel.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,11 +1054,17 @@ ur_result_t urKernelGetNativeHandle(
10541054
}
10551055

10561056
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
1057-
ur_kernel_handle_t hKernel, size_t localWorkSize,
1057+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
10581058
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1059-
(void)localWorkSize;
10601059
(void)dynamicSharedMemorySize;
10611060
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
1061+
1062+
uint32_t WG[3];
1063+
WG[0] = ur_cast<uint32_t>(pLocalWorkSize[0]);
1064+
WG[1] = workDim >= 2 ? ur_cast<uint32_t>(pLocalWorkSize[1]) : 1;
1065+
WG[2] = workDim == 3 ? ur_cast<uint32_t>(pLocalWorkSize[2]) : 1;
1066+
ZE2UR_CALL(zeKernelSetGroupSize, (hKernel->ZeKernel, WG[0], WG[1], WG[2]));
1067+
10621068
uint32_t TotalGroupCount = 0;
10631069
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
10641070
(hKernel->ZeKernel, &TotalGroupCount));

source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp(
687687
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
688688
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);
689689
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
690-
ur_kernel_handle_t hKernel, size_t localWorkSize,
690+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
691691
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet);
692692
ur_result_t urEnqueueTimestampRecordingExp(
693693
ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList,

source/adapters/level_zero/v2/api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ ur_result_t urCommandBufferCommandGetInfoExp(
568568
}
569569

570570
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
571-
ur_kernel_handle_t hKernel, size_t localWorkSize,
571+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
572572
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
573573
logger::error("{} function not implemented!", __FUNCTION__);
574574
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/mock/ur_mockddi.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10003,9 +10003,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1000310003
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
1000410004
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1000510005
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
10006-
size_t
10007-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
10008-
///< kernel is launched
10006+
uint32_t
10007+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
10008+
///< work-items
10009+
const size_t *
10010+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
10011+
///< number of local work-items forming a work-group that will execute the
10012+
///< kernel function.
1000910013
size_t
1001010014
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
1001110015
///< that will be used when the kernel is launched
@@ -10014,7 +10018,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1001410018
ur_result_t result = UR_RESULT_SUCCESS;
1001510019

1001610020
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
10017-
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
10021+
&hKernel, &workDim, &pLocalWorkSize, &dynamicSharedMemorySize,
10022+
&pGroupCountRet};
1001810023

1001910024
auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
1002010025
mock::getCallbacks().get_before_callback(

source/adapters/opencl/kernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
390390

391391
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
392392
[[maybe_unused]] ur_kernel_handle_t hKernel,
393-
[[maybe_unused]] size_t localWorkSize,
393+
[[maybe_unused]] uint32_t workDim,
394+
[[maybe_unused]] const size_t *pLocalWorkSize,
394395
[[maybe_unused]] size_t dynamicSharedMemorySize,
395396
[[maybe_unused]] uint32_t *pGroupCountRet) {
396397
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/loader/layers/tracing/ur_trcddi.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8585,9 +8585,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
85858585
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
85868586
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
85878587
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8588-
size_t
8589-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8590-
///< kernel is launched
8588+
uint32_t
8589+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
8590+
///< work-items
8591+
const size_t *
8592+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
8593+
///< number of local work-items forming a work-group that will execute the
8594+
///< kernel function.
85918595
size_t
85928596
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
85938597
///< that will be used when the kernel is launched
@@ -8602,7 +8606,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
86028606
}
86038607

86048608
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
8605-
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
8609+
&hKernel, &workDim, &pLocalWorkSize, &dynamicSharedMemorySize,
8610+
&pGroupCountRet};
86068611
uint64_t instance = getContext()->notify_begin(
86078612
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,
86088613
"urKernelSuggestMaxCooperativeGroupCountExp", &params);
@@ -8611,7 +8616,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
86118616
logger.info(" ---> urKernelSuggestMaxCooperativeGroupCountExp\n");
86128617

86138618
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
8614-
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
8619+
hKernel, workDim, pLocalWorkSize, dynamicSharedMemorySize,
8620+
pGroupCountRet);
86158621

86168622
getContext()->notify_end(
86178623
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9613,9 +9613,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
96139613
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
96149614
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
96159615
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
9616-
size_t
9617-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
9618-
///< kernel is launched
9616+
uint32_t
9617+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
9618+
///< work-items
9619+
const size_t *
9620+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
9621+
///< number of local work-items forming a work-group that will execute the
9622+
///< kernel function.
96199623
size_t
96209624
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
96219625
///< that will be used when the kernel is launched
@@ -9634,6 +9638,10 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
96349638
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
96359639
}
96369640

9641+
if (NULL == pLocalWorkSize) {
9642+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
9643+
}
9644+
96379645
if (NULL == pGroupCountRet) {
96389646
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
96399647
}
@@ -9645,7 +9653,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
96459653
}
96469654

96479655
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
9648-
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
9656+
hKernel, workDim, pLocalWorkSize, dynamicSharedMemorySize,
9657+
pGroupCountRet);
96499658

96509659
return result;
96519660
}

source/loader/ur_ldrddi.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8760,9 +8760,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
87608760
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
87618761
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
87628762
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8763-
size_t
8764-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8765-
///< kernel is launched
8763+
uint32_t
8764+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
8765+
///< work-items
8766+
const size_t *
8767+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
8768+
///< number of local work-items forming a work-group that will execute the
8769+
///< kernel function.
87668770
size_t
87678771
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
87688772
///< that will be used when the kernel is launched
@@ -8785,7 +8789,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
87858789

87868790
// forward to device-platform
87878791
result = pfnSuggestMaxCooperativeGroupCountExp(
8788-
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
8792+
hKernel, workDim, pLocalWorkSize, dynamicSharedMemorySize,
8793+
pGroupCountRet);
87898794

87908795
return result;
87918796
}

source/loader/ur_libapi.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8893,13 +8893,18 @@ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
88938893
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
88948894
/// + `NULL == hKernel`
88958895
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
8896+
/// + `NULL == pLocalWorkSize`
88968897
/// + `NULL == pGroupCountRet`
88978898
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
88988899
ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
88998900
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8900-
size_t
8901-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8902-
///< kernel is launched
8901+
uint32_t
8902+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
8903+
///< work-items
8904+
const size_t *
8905+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
8906+
///< number of local work-items forming a work-group that will execute the
8907+
///< kernel function.
89038908
size_t
89048909
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
89058910
///< that will be used when the kernel is launched
@@ -8913,7 +8918,8 @@ ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
89138918
}
89148919

89158920
return pfnSuggestMaxCooperativeGroupCountExp(
8916-
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
8921+
hKernel, workDim, pLocalWorkSize, dynamicSharedMemorySize,
8922+
pGroupCountRet);
89178923
} catch (...) {
89188924
return exceptionToResult(std::current_exception());
89198925
}

source/ur_api.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7543,13 +7543,18 @@ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
75437543
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
75447544
/// + `NULL == hKernel`
75457545
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
7546+
/// + `NULL == pLocalWorkSize`
75467547
/// + `NULL == pGroupCountRet`
75477548
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
75487549
ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
75497550
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
7550-
size_t
7551-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
7552-
///< kernel is launched
7551+
uint32_t
7552+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
7553+
///< work-items
7554+
const size_t *
7555+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
7556+
///< number of local work-items forming a work-group that will execute the
7557+
///< kernel function.
75537558
size_t
75547559
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
75557560
///< that will be used when the kernel is launched

0 commit comments

Comments
 (0)