Skip to content

Commit d2d56d1

Browse files
committed
Fix failures on cpu device
1 parent 7f3ebec commit d2d56d1

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

source/loader/layers/sanitizer/asan_interceptor.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,20 @@ ur_result_t SanitizerInterceptor::prepareLaunch(
634634
}
635635
}
636636

637+
auto ArgNums = GetKernelNumArgs(Kernel);
638+
// We must prepare all kernel args before call
639+
// urKernelGetSuggestedLocalWorkSize, otherwise the call will fail on
640+
// CPU device.
641+
if (ArgNums) {
642+
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
643+
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr());
644+
if (URes != UR_RESULT_SUCCESS) {
645+
getContext()->logger.error("Failed to set launch info: {}",
646+
URes);
647+
return URes;
648+
}
649+
}
650+
637651
if (LaunchInfo.LocalWorkSize.empty()) {
638652
LaunchInfo.LocalWorkSize.resize(LaunchInfo.WorkDim);
639653
auto URes =
@@ -660,11 +674,6 @@ ur_result_t SanitizerInterceptor::prepareLaunch(
660674
LocalWorkSize[Dim];
661675
}
662676

663-
auto ArgNums = GetKernelNumArgs(Kernel);
664-
if (ArgNums == 0) {
665-
return UR_RESULT_SUCCESS;
666-
}
667-
668677
// Prepare asan runtime data
669678
LaunchInfo.Data.Host.GlobalShadowOffset =
670679
DeviceInfo->Shadow->ShadowBegin;
@@ -804,14 +813,6 @@ ur_result_t SanitizerInterceptor::prepareLaunch(
804813
// sync asan runtime data to device side
805814
UR_CALL(LaunchInfo.Data.syncToDevice(Queue));
806815

807-
// set kernel argument
808-
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
809-
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr());
810-
if (URes != UR_RESULT_SUCCESS) {
811-
getContext()->logger.error("Failed to set launch info: {}", URes);
812-
return URes;
813-
}
814-
815816
getContext()->logger.debug(
816817
"launch_info {} (numLocalArgs={}, localArgs={})",
817818
(void *)LaunchInfo.Data.getDevicePtr(),

0 commit comments

Comments
 (0)