Skip to content

Commit 6adbb52

Browse files
authored
Merge pull request #2527 from RossBrunton/ross/wrapper
Wrap urEventSetCallback when ran through loader
2 parents 4a91696 + 2dc42f6 commit 6adbb52

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,37 @@ from templates import helper as th
2424
namespace ur_loader
2525
{
2626
%for obj in th.get_adapter_functions(specs):
27+
<%
28+
func_name = th.make_func_name(n, tags, obj)
29+
if func_name.startswith(x):
30+
func_basename = func_name[len(x):]
31+
else:
32+
func_basename = func_name
33+
%>
34+
%if func_basename == "EventSetCallback":
35+
namespace {
36+
struct event_callback_wrapper_data_t {
37+
${x}_event_callback_t fn;
38+
${x}_event_handle_t event;
39+
void *userData;
40+
};
41+
42+
void event_callback_wrapper([[maybe_unused]] ${x}_event_handle_t hEvent,
43+
${x}_execution_info_t execStatus, void *pUserData) {
44+
auto *wrapper =
45+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
46+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47+
delete wrapper;
48+
}
49+
}
50+
51+
%endif
2752
///////////////////////////////////////////////////////////////////////////////
28-
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
53+
/// @brief Intercept function for ${func_name}
2954
%if 'condition' in obj:
3055
#if ${th.subt(n, tags, obj['condition'])}
3156
%endif
32-
__${x}dlllocal ${x}_result_t ${X}_APICALL
33-
${th.make_func_name(n, tags, obj)}(
57+
__${x}dlllocal ${x}_result_t ${X}_APICALL ${func_name}(
3458
%for line in th.make_param_lines(n, tags, obj):
3559
${line}
3660
%endfor
@@ -41,7 +65,16 @@ namespace ur_loader
4165
%>${th.get_initial_null_set(obj)}
4266

4367
[[maybe_unused]] auto context = getContext();
44-
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
68+
%if func_basename == "EventSetCallback":
69+
70+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
71+
// backend-specific event
72+
auto wrapper_data =
73+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
74+
pUserData = wrapper_data;
75+
pfnNotify = event_callback_wrapper;
76+
%endif
77+
%if func_basename == "AdapterGet":
4578

4679
size_t adapterIndex = 0;
4780
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
@@ -74,7 +107,7 @@ namespace ur_loader
74107
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
75108
}
76109

77-
%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
110+
%elif func_basename == "PlatformGet":
78111
uint32_t total_platform_handle_count = 0;
79112

80113
for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
@@ -263,7 +296,7 @@ namespace ur_loader
263296
%for i, item in enumerate(epilogue):
264297
%if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj):
265298
## TODO: Remove once we have a concrete way for submitting warnings in place.
266-
%if re.match(r"urEnqueue\w+", th.make_func_name(n, tags, obj)):
299+
%if re.match(r"Enqueue\w+", func_basename):
267300
// In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268301
if( ${X}_RESULT_SUCCESS != result && ${X}_RESULT_ERROR_ADAPTER_SPECIFIC != result )
269302
return result;

source/loader/ur_ldrddi.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ur_loader.hpp"
1414

1515
namespace ur_loader {
16+
1617
///////////////////////////////////////////////////////////////////////////////
1718
/// @brief Intercept function for urAdapterGet
1819
__urdlllocal ur_result_t UR_APICALL urAdapterGet(
@@ -4476,6 +4477,22 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
44764477
return result;
44774478
}
44784479

4480+
namespace {
4481+
struct event_callback_wrapper_data_t {
4482+
ur_event_callback_t fn;
4483+
ur_event_handle_t event;
4484+
void *userData;
4485+
};
4486+
4487+
void event_callback_wrapper([[maybe_unused]] ur_event_handle_t hEvent,
4488+
ur_execution_info_t execStatus, void *pUserData) {
4489+
auto *wrapper =
4490+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
4491+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
4492+
delete wrapper;
4493+
}
4494+
} // namespace
4495+
44794496
///////////////////////////////////////////////////////////////////////////////
44804497
/// @brief Intercept function for urEventSetCallback
44814498
__urdlllocal ur_result_t UR_APICALL urEventSetCallback(
@@ -4489,6 +4506,13 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback(
44894506

44904507
[[maybe_unused]] auto context = getContext();
44914508

4509+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
4510+
// backend-specific event
4511+
auto wrapper_data =
4512+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
4513+
pUserData = wrapper_data;
4514+
pfnNotify = event_callback_wrapper;
4515+
44924516
// extract platform's function pointer table
44934517
auto dditable = reinterpret_cast<ur_event_object_t *>(hEvent)->dditable;
44944518
auto pfnSetCallback = dditable->ur.Event.pfnSetCallback;

test/conformance/event/urEventSetCallback.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ TEST_P(urEventSetCallbackTest, Success) {
4141
*/
4242
TEST_P(urEventSetCallbackTest, ValidateParameters) {
4343
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{}, uur::LevelZero{},
44-
uur::LevelZeroV2{}, uur::OpenCL{}, uur::NativeCPU{});
44+
uur::LevelZeroV2{}, uur::NativeCPU{});
4545

4646
struct CallbackParameters {
4747
ur_event_handle_t event;

0 commit comments

Comments
 (0)