@@ -907,17 +907,18 @@ ur_result_t setKernelPendingArguments(
907
907
* @param[in] CommandBuffer The CommandBuffer associated with the new command.
908
908
* @param[in] Kernel The Kernel associated with the new command.
909
909
* @param[in] WorkDim Dimensions of the kernel associated with the new command.
910
+ * @param[in] GlobalWorkSize Global work size of the kernel associated with the
911
+ * new command.
910
912
* @param[in] LocalWorkSize LocalWorkSize of the kernel associated with the new
911
913
* command.
912
914
* @param[out] Command The handle to the new command.
913
915
* @return UR_RESULT_SUCCESS or an error code on failure
914
916
*/
915
- ur_result_t
916
- createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
917
- ur_kernel_handle_t Kernel, uint32_t WorkDim,
918
- const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
919
- ur_kernel_handle_t *KernelAlternatives,
920
- ur_exp_command_buffer_command_handle_t *Command) {
917
+ ur_result_t createCommandHandle (
918
+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
919
+ uint32_t WorkDim, const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
920
+ uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
921
+ ur_exp_command_buffer_command_handle_t *Command) {
921
922
922
923
assert (CommandBuffer->IsUpdatable );
923
924
@@ -989,6 +990,8 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
989
990
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr ,
990
991
NumKernelAlternatives, KernelAlternatives);
991
992
993
+ NewCommand->setGlobalWorkSize (GlobalWorkSize);
994
+
992
995
*Command = NewCommand.get ();
993
996
994
997
CommandBuffer->CommandHandles .push_back (std::move (NewCommand));
@@ -1063,9 +1066,9 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1063
1066
}
1064
1067
1065
1068
if (Command) {
1066
- UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, LocalWorkSize ,
1067
- NumKernelAlternatives, KernelAlternatives ,
1068
- Command));
1069
+ UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, GlobalWorkSize ,
1070
+ LocalWorkSize, NumKernelAlternatives ,
1071
+ KernelAlternatives, Command));
1069
1072
}
1070
1073
std::vector<ze_event_handle_t > ZeEventList;
1071
1074
ze_event_handle_t ZeLaunchEvent = nullptr ;
@@ -1919,10 +1922,16 @@ ur_result_t updateKernelCommand(
1919
1922
Descs.push_back (std::move (MutableGroupSizeDesc));
1920
1923
}
1921
1924
1922
- // Check if a new global size is provided and if we need to update the group
1923
- // count.
1925
+ // Check if a new global or local size is provided and if so we need to update
1926
+ // the group count.
1924
1927
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
1925
- if (NewGlobalWorkSize && Dim > 0 ) {
1928
+ if ((NewGlobalWorkSize || NewLocalWorkSize) && Dim > 0 ) {
1929
+ // If a new global work size is provided update that in the command,
1930
+ // otherwise the previous work group size will be used
1931
+ if (NewGlobalWorkSize) {
1932
+ Command->WorkDim = Dim;
1933
+ Command->setGlobalWorkSize (NewGlobalWorkSize);
1934
+ }
1926
1935
// If a new global work size is provided but a new local work size is not
1927
1936
// then we still need to update local work size based on the size suggested
1928
1937
// by the driver for the kernel.
@@ -1932,9 +1941,9 @@ ur_result_t updateKernelCommand(
1932
1941
UR_CALL (getZeKernel (ZeDevice, Command->Kernel , &ZeKernel));
1933
1942
1934
1943
uint32_t WG[3 ];
1935
- UR_CALL (calculateKernelWorkDimensions (ZeKernel, CommandBuffer-> Device ,
1936
- ZeThreadGroupDimensions, WG, Dim,
1937
- NewGlobalWorkSize , NewLocalWorkSize));
1944
+ UR_CALL (calculateKernelWorkDimensions (
1945
+ ZeKernel, CommandBuffer-> Device , ZeThreadGroupDimensions, WG, Dim,
1946
+ Command-> GlobalWorkSize , NewLocalWorkSize));
1938
1947
1939
1948
auto MutableGroupCountDesc =
1940
1949
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
0 commit comments