Skip to content

Commit 341114d

Browse files
committed
[UR][Loader] Fix handling of native handles
Native handles are created by adapters and thus are inheritently backend-specific. Loader can not assume anything about these handles, as even nullptr may be a valid value for such a handle. This patch changes two things about native handles: 1) Native handles are no longer wrapped in UR objects 2) Dispatch table is extracted from any other argument of the API function The above is true for all interop APIs except for urPlatformCreateWithNativeHandle, which needs a spec change.
1 parent a1fbbde commit 341114d

File tree

2 files changed

+16
-122
lines changed

2 files changed

+16
-122
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,17 @@ namespace ur_loader
127127
%else:
128128
<%param_replacements={}%>
129129
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
130-
%if 0 == i:
130+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
131131
// extract platform's function pointer table
132132
auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable;
133133
auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
134134
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
135135
return ${X}_RESULT_ERROR_UNINITIALIZED;
136136
137+
<%break%>
137138
%endif
139+
%endfor
140+
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
138141
%if 'range' in item:
139142
<%
140143
add_local = True
@@ -143,13 +146,15 @@ namespace ur_loader
143146
for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i )
144147
${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle;
145148
%else:
149+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
146150
// convert loader handle to platform handle
147151
%if item['optional']:
148152
${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr;
149153
%else:
150154
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
151155
%endif
152156
%endif
157+
%endif
153158
154159
%endfor
155160
// forward to device-platform
@@ -170,7 +175,7 @@ namespace ur_loader
170175
%if item['release']:
171176
// release loader handle
172177
${item['factory']}.release( ${item['name']} );
173-
%else:
178+
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
174179
try
175180
{
176181
%if 'range' in item:

source/loader/ur_ldrddi.cpp

Lines changed: 9 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -349,14 +349,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle(
349349
return result;
350350
}
351351

352-
try {
353-
// convert platform handle to loader handle
354-
*phNativePlatform = reinterpret_cast<ur_native_handle_t>(
355-
ur_native_factory.getInstance(*phNativePlatform, dditable));
356-
} catch (std::bad_alloc &) {
357-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
358-
}
359-
360352
return result;
361353
}
362354

@@ -670,14 +662,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
670662
return result;
671663
}
672664

673-
try {
674-
// convert platform handle to loader handle
675-
*phNativeDevice = reinterpret_cast<ur_native_handle_t>(
676-
ur_native_factory.getInstance(*phNativeDevice, dditable));
677-
} catch (std::bad_alloc &) {
678-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
679-
}
680-
681665
return result;
682666
}
683667

@@ -696,17 +680,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
696680

697681
// extract platform's function pointer table
698682
auto dditable =
699-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->dditable;
683+
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
700684
auto pfnCreateWithNativeHandle =
701685
dditable->ur.Device.pfnCreateWithNativeHandle;
702686
if (nullptr == pfnCreateWithNativeHandle) {
703687
return UR_RESULT_ERROR_UNINITIALIZED;
704688
}
705689

706-
// convert loader handle to platform handle
707-
hNativeDevice =
708-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->handle;
709-
710690
// convert loader handle to platform handle
711691
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
712692

@@ -913,14 +893,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
913893
return result;
914894
}
915895

916-
try {
917-
// convert platform handle to loader handle
918-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
919-
ur_native_factory.getInstance(*phNativeContext, dditable));
920-
} catch (std::bad_alloc &) {
921-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
922-
}
923-
924896
return result;
925897
}
926898

@@ -941,17 +913,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
941913

942914
// extract platform's function pointer table
943915
auto dditable =
944-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->dditable;
916+
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
945917
auto pfnCreateWithNativeHandle =
946918
dditable->ur.Context.pfnCreateWithNativeHandle;
947919
if (nullptr == pfnCreateWithNativeHandle) {
948920
return UR_RESULT_ERROR_UNINITIALIZED;
949921
}
950922

951-
// convert loader handle to platform handle
952-
hNativeContext =
953-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->handle;
954-
955923
// convert loader handles to platform handles
956924
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
957925
for (size_t i = 0; i < numDevices; ++i) {
@@ -1204,14 +1172,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
12041172
return result;
12051173
}
12061174

1207-
try {
1208-
// convert platform handle to loader handle
1209-
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
1210-
ur_native_factory.getInstance(*phNativeMem, dditable));
1211-
} catch (std::bad_alloc &) {
1212-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1213-
}
1214-
12151175
return result;
12161176
}
12171177

@@ -1229,17 +1189,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
12291189
ur_result_t result = UR_RESULT_SUCCESS;
12301190

12311191
// extract platform's function pointer table
1232-
auto dditable =
1233-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1192+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12341193
auto pfnBufferCreateWithNativeHandle =
12351194
dditable->ur.Mem.pfnBufferCreateWithNativeHandle;
12361195
if (nullptr == pfnBufferCreateWithNativeHandle) {
12371196
return UR_RESULT_ERROR_UNINITIALIZED;
12381197
}
12391198

1240-
// convert loader handle to platform handle
1241-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1242-
12431199
// convert loader handle to platform handle
12441200
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12451201

@@ -1279,17 +1235,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
12791235
ur_result_t result = UR_RESULT_SUCCESS;
12801236

12811237
// extract platform's function pointer table
1282-
auto dditable =
1283-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1238+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12841239
auto pfnImageCreateWithNativeHandle =
12851240
dditable->ur.Mem.pfnImageCreateWithNativeHandle;
12861241
if (nullptr == pfnImageCreateWithNativeHandle) {
12871242
return UR_RESULT_ERROR_UNINITIALIZED;
12881243
}
12891244

1290-
// convert loader handle to platform handle
1291-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1292-
12931245
// convert loader handle to platform handle
12941246
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12951247

@@ -1525,14 +1477,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle(
15251477
return result;
15261478
}
15271479

1528-
try {
1529-
// convert platform handle to loader handle
1530-
*phNativeSampler = reinterpret_cast<ur_native_handle_t>(
1531-
ur_native_factory.getInstance(*phNativeSampler, dditable));
1532-
} catch (std::bad_alloc &) {
1533-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1534-
}
1535-
15361480
return result;
15371481
}
15381482

@@ -1550,18 +1494,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
15501494
ur_result_t result = UR_RESULT_SUCCESS;
15511495

15521496
// extract platform's function pointer table
1553-
auto dditable =
1554-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->dditable;
1497+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
15551498
auto pfnCreateWithNativeHandle =
15561499
dditable->ur.Sampler.pfnCreateWithNativeHandle;
15571500
if (nullptr == pfnCreateWithNativeHandle) {
15581501
return UR_RESULT_ERROR_UNINITIALIZED;
15591502
}
15601503

1561-
// convert loader handle to platform handle
1562-
hNativeSampler =
1563-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->handle;
1564-
15651504
// convert loader handle to platform handle
15661505
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
15671506

@@ -2601,14 +2540,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle(
26012540
return result;
26022541
}
26032542

2604-
try {
2605-
// convert platform handle to loader handle
2606-
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(
2607-
ur_native_factory.getInstance(*phNativeProgram, dditable));
2608-
} catch (std::bad_alloc &) {
2609-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
2610-
}
2611-
26122543
return result;
26132544
}
26142545

@@ -2626,18 +2557,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
26262557
ur_result_t result = UR_RESULT_SUCCESS;
26272558

26282559
// extract platform's function pointer table
2629-
auto dditable =
2630-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->dditable;
2560+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
26312561
auto pfnCreateWithNativeHandle =
26322562
dditable->ur.Program.pfnCreateWithNativeHandle;
26332563
if (nullptr == pfnCreateWithNativeHandle) {
26342564
return UR_RESULT_ERROR_UNINITIALIZED;
26352565
}
26362566

2637-
// convert loader handle to platform handle
2638-
hNativeProgram =
2639-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->handle;
2640-
26412567
// convert loader handle to platform handle
26422568
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
26432569

@@ -3085,14 +3011,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle(
30853011
return result;
30863012
}
30873013

3088-
try {
3089-
// convert platform handle to loader handle
3090-
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(
3091-
ur_native_factory.getInstance(*phNativeKernel, dditable));
3092-
} catch (std::bad_alloc &) {
3093-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3094-
}
3095-
30963014
return result;
30973015
}
30983016

@@ -3112,18 +3030,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
31123030
ur_result_t result = UR_RESULT_SUCCESS;
31133031

31143032
// extract platform's function pointer table
3115-
auto dditable =
3116-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->dditable;
3033+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
31173034
auto pfnCreateWithNativeHandle =
31183035
dditable->ur.Kernel.pfnCreateWithNativeHandle;
31193036
if (nullptr == pfnCreateWithNativeHandle) {
31203037
return UR_RESULT_ERROR_UNINITIALIZED;
31213038
}
31223039

3123-
// convert loader handle to platform handle
3124-
hNativeKernel =
3125-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->handle;
3126-
31273040
// convert loader handle to platform handle
31283041
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
31293042

@@ -3297,14 +3210,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle(
32973210
return result;
32983211
}
32993212

3300-
try {
3301-
// convert platform handle to loader handle
3302-
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
3303-
ur_native_factory.getInstance(*phNativeQueue, dditable));
3304-
} catch (std::bad_alloc &) {
3305-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3306-
}
3307-
33083213
return result;
33093214
}
33103215

@@ -3323,17 +3228,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
33233228
ur_result_t result = UR_RESULT_SUCCESS;
33243229

33253230
// extract platform's function pointer table
3326-
auto dditable =
3327-
reinterpret_cast<ur_native_object_t *>(hNativeQueue)->dditable;
3231+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
33283232
auto pfnCreateWithNativeHandle =
33293233
dditable->ur.Queue.pfnCreateWithNativeHandle;
33303234
if (nullptr == pfnCreateWithNativeHandle) {
33313235
return UR_RESULT_ERROR_UNINITIALIZED;
33323236
}
33333237

3334-
// convert loader handle to platform handle
3335-
hNativeQueue = reinterpret_cast<ur_native_object_t *>(hNativeQueue)->handle;
3336-
33373238
// convert loader handle to platform handle
33383239
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
33393240

@@ -3570,14 +3471,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle(
35703471
return result;
35713472
}
35723473

3573-
try {
3574-
// convert platform handle to loader handle
3575-
*phNativeEvent = reinterpret_cast<ur_native_handle_t>(
3576-
ur_native_factory.getInstance(*phNativeEvent, dditable));
3577-
} catch (std::bad_alloc &) {
3578-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3579-
}
3580-
35813474
return result;
35823475
}
35833476

@@ -3595,17 +3488,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
35953488
ur_result_t result = UR_RESULT_SUCCESS;
35963489

35973490
// extract platform's function pointer table
3598-
auto dditable =
3599-
reinterpret_cast<ur_native_object_t *>(hNativeEvent)->dditable;
3491+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
36003492
auto pfnCreateWithNativeHandle =
36013493
dditable->ur.Event.pfnCreateWithNativeHandle;
36023494
if (nullptr == pfnCreateWithNativeHandle) {
36033495
return UR_RESULT_ERROR_UNINITIALIZED;
36043496
}
36053497

3606-
// convert loader handle to platform handle
3607-
hNativeEvent = reinterpret_cast<ur_native_object_t *>(hNativeEvent)->handle;
3608-
36093498
// convert loader handle to platform handle
36103499
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
36113500

0 commit comments

Comments
 (0)