Skip to content

Commit 99277f6

Browse files
authored
[UR] urPlatformGet() takes only 1 adapter. (#17876)
Fixes #17504 .
1 parent 546b114 commit 99277f6

File tree

41 files changed

+222
-276
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+222
-276
lines changed

sycl/source/detail/adapter.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ class Adapter {
9898
std::vector<ur_platform_handle_t> &getUrPlatforms() {
9999
std::call_once(PlatformsPopulated, [&]() {
100100
uint32_t platformCount = 0;
101-
call<UrApiKind::urPlatformGet>(&MAdapter, 1, 0, nullptr, &platformCount);
101+
call<UrApiKind::urPlatformGet>(MAdapter, 0, nullptr, &platformCount);
102102
UrPlatforms.resize(platformCount);
103103
if (platformCount) {
104-
call<UrApiKind::urPlatformGet>(&MAdapter, 1, platformCount,
104+
call<UrApiKind::urPlatformGet>(MAdapter, platformCount,
105105
UrPlatforms.data(), nullptr);
106106
}
107107
// We need one entry in this per platform

unified-runtime/examples/codegen/codegen.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,21 @@ get_supported_adapters(std::vector<ur_adapter_handle_t> &adapters) {
6464
std::vector<ur_platform_handle_t>
6565
get_platforms(std::vector<ur_adapter_handle_t> &adapters) {
6666
uint32_t platformCount = 0;
67-
ur_check(urPlatformGet(adapters.data(), adapters.size(), 1, nullptr,
68-
&platformCount));
67+
std::vector<ur_platform_handle_t> platforms;
68+
for (auto adapter : adapters) {
69+
uint32_t adapterPlatformCount = 0;
70+
urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
6971

72+
platforms.reserve(platformCount + adapterPlatformCount);
73+
urPlatformGet(adapter, adapterPlatformCount, &platforms[platformCount],
74+
&adapterPlatformCount);
75+
platformCount += adapterPlatformCount;
76+
}
7077
if (!platformCount) {
7178
throw std::runtime_error("No platforms available.");
7279
}
80+
platforms.resize(platformCount);
7381

74-
std::vector<ur_platform_handle_t> platforms(platformCount);
75-
ur_check(urPlatformGet(adapters.data(), adapters.size(), platformCount,
76-
platforms.data(), nullptr));
7782
return platforms;
7883
}
7984

unified-runtime/examples/hello_world/hello_world.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,26 @@ int main(int, char *[]) {
4747
return 1;
4848
}
4949

50-
status =
51-
urPlatformGet(adapters.data(), adapterCount, 1, nullptr, &platformCount);
52-
if (status != UR_RESULT_SUCCESS) {
53-
std::cout << "urPlatformGet failed with return code: " << status
54-
<< std::endl;
55-
goto out;
56-
}
50+
for (auto adapter : adapters) {
51+
uint32_t adapterPlatformCount = 0;
52+
status = urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
53+
if (status != UR_RESULT_SUCCESS) {
54+
std::cout << "urPlatformGet failed with return code: " << status
55+
<< std::endl;
56+
goto out;
57+
}
5758

58-
platforms.resize(platformCount);
59-
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
60-
platforms.data(), nullptr);
61-
if (status != UR_RESULT_SUCCESS) {
62-
std::cout << "urPlatformGet failed with return code: " << status
63-
<< std::endl;
64-
goto out;
59+
platforms.reserve(platformCount + adapterPlatformCount);
60+
status = urPlatformGet(adapter, adapterPlatformCount,
61+
&platforms[platformCount], &adapterPlatformCount);
62+
if (status != UR_RESULT_SUCCESS) {
63+
std::cout << "urPlatformGet failed with return code: " << status
64+
<< std::endl;
65+
goto out;
66+
}
67+
platformCount += adapterPlatformCount;
6568
}
69+
platforms.resize(platformCount);
6670

6771
for (auto p : platforms) {
6872
ur_api_version_t api_version = {};

unified-runtime/include/ur_api.h

Lines changed: 5 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_ddi.h

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_print.hpp

Lines changed: 2 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/scripts/core/PROG.rst

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,16 @@ Initialization and Discovery
5858
${x}AdapterGet(adapterCount, adapters.data(), nullptr);
5959
6060
// Discover all the platform instances
61-
uint32_t platformCount = 0;
62-
${x}PlatformGet(adapters.data(), adapterCount, 0, nullptr, &platformCount);
63-
64-
std::vector<${x}_platform_handle_t> platforms(platformCount);
65-
${x}PlatformGet(adapters.data(), adapterCount, platform.size(), platforms.data(), &platformCount);
61+
std::vector<${x}_platform_handle_t> platforms;
62+
uint32_t totalPlatformCount = 0;
63+
for (auto adapter : adapters) {
64+
uint32_t adapterPlatformCount = 0;
65+
${x}PlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
66+
67+
platforms.reserve(totalPlatformCount + adapterPlatformCount);
68+
${x}PlatformGet(adapter, adapterPlatformCount, &platforms[totalPlatformCount], &adapterPlatformCount);
69+
totalPlatformCount += adapterPlatformCount;
70+
}
6671
6772
// Get number of total GPU devices in the platform
6873
uint32_t deviceCount = 0;

unified-runtime/scripts/core/platform.yml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@ details:
2424
- "Multiple calls to this function will return identical platforms handles, in the same order."
2525
- "The application may call this function from simultaneous threads, the implementation must be thread-safe"
2626
params:
27-
- type: "$x_adapter_handle_t*"
28-
name: "phAdapters"
29-
desc: "[in][range(0, NumAdapters)] array of adapters to query for platforms."
30-
- type: "uint32_t"
31-
name: "NumAdapters"
32-
desc: "[in] number of adapters pointed to by phAdapters"
27+
- type: "$x_adapter_handle_t"
28+
name: "hAdapter"
29+
desc: "[in] adapter to query for platforms."
3330
- type: "uint32_t"
3431
name: NumEntries
3532
desc: |

unified-runtime/scripts/templates/ldrddi.cpp.mako

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -101,49 +101,39 @@ namespace ur_loader
101101
}
102102

103103
%elif func_basename == "PlatformGet":
104-
uint32_t total_platform_handle_count = 0;
105104

106-
for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
107-
{
108-
// extract adapter's function pointer table
109-
auto dditable =
110-
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']}[adapter_index])->dditable;
105+
// extract adapter's function pointer table
106+
auto dditable =
107+
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']})->dditable;
111108

112-
if( ( 0 < ${obj['params'][2]['name']} ) && ( ${obj['params'][2]['name']} == total_platform_handle_count))
113-
break;
109+
uint32_t library_platform_handle_count = 0;
114110

115-
uint32_t library_platform_handle_count = 0;
111+
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, 0, nullptr, &library_platform_handle_count );
112+
if( ${X}_RESULT_SUCCESS != result ) return result;
116113

117-
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, 0, nullptr, &library_platform_handle_count );
118-
if( ${X}_RESULT_SUCCESS != result ) break;
114+
if( nullptr != ${obj['params'][2]['name']} && ${obj['params'][1]['name']} !=0)
115+
{
116+
if( library_platform_handle_count > ${obj['params'][1]['name']}) {
117+
library_platform_handle_count = ${obj['params'][1]['name']};
118+
}
119+
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, library_platform_handle_count, ${obj['params'][2]['name']}, nullptr );
120+
if( ${X}_RESULT_SUCCESS != result ) return result;
119121

120-
if( nullptr != ${obj['params'][3]['name']} && ${obj['params'][2]['name']} !=0)
122+
try
121123
{
122-
if( total_platform_handle_count + library_platform_handle_count > ${obj['params'][2]['name']}) {
123-
library_platform_handle_count = ${obj['params'][2]['name']} - total_platform_handle_count;
124-
}
125-
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, library_platform_handle_count, &${obj['params'][3]['name']}[ total_platform_handle_count ], nullptr );
126-
if( ${X}_RESULT_SUCCESS != result ) break;
127-
128-
try
129-
{
130-
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
131-
uint32_t platform_index = total_platform_handle_count + i;
132-
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
133-
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
134-
}
135-
}
136-
catch( std::bad_alloc& )
137-
{
138-
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
124+
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
125+
${obj['params'][2]['name']}[ i ] = reinterpret_cast<${n}_platform_handle_t>(
126+
context->factories.${n}_platform_factory.getInstance( ${obj['params'][2]['name']}[ i ], dditable ) );
139127
}
140128
}
141-
142-
total_platform_handle_count += library_platform_handle_count;
129+
catch( std::bad_alloc& )
130+
{
131+
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
132+
}
143133
}
144134

145-
if( ${X}_RESULT_SUCCESS == result && ${obj['params'][4]['name']} != nullptr )
146-
*${obj['params'][4]['name']} = total_platform_handle_count;
135+
if( ${X}_RESULT_SUCCESS == result && ${obj['params'][3]['name']} != nullptr )
136+
*${obj['params'][3]['name']} = library_platform_handle_count;
147137

148138
%else:
149139
<%param_replacements={}%>

unified-runtime/source/adapters/cuda/device.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,15 +1285,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
12851285
// Get list of platforms
12861286
uint32_t NumPlatforms = 0;
12871287
ur_adapter_handle_t AdapterHandle = &adapter;
1288-
ur_result_t Result =
1289-
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
1288+
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
12901289
if (Result != UR_RESULT_SUCCESS)
12911290
return Result;
12921291

12931292
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
12941293

12951294
Result =
1296-
urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), nullptr);
1295+
urPlatformGet(AdapterHandle, NumPlatforms, Platforms.data(), nullptr);
12971296
if (Result != UR_RESULT_SUCCESS)
12981297
return Result;
12991298

unified-runtime/source/adapters/cuda/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
114114
/// Triggers the CUDA Driver initialization (cuInit) the first time, so this
115115
/// must be the first PI API called.
116116
UR_APIEXPORT ur_result_t UR_APICALL
117-
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
117+
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
118118
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
119119

120120
try {

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
243243
// cuda backend has only one platform containing all devices
244244
ur_platform_handle_t platform;
245245
ur_adapter_handle_t AdapterHandle = &adapter;
246-
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
246+
Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr);
247247

248248
// get the device from the platform
249249
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();

unified-runtime/source/adapters/hip/device.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,8 +1182,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
11821182
// Get list of platforms
11831183
uint32_t NumPlatforms = 0;
11841184
ur_adapter_handle_t AdapterHandle = &adapter;
1185-
ur_result_t Result =
1186-
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
1185+
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
11871186
if (Result != UR_RESULT_SUCCESS)
11881187
return Result;
11891188

@@ -1193,7 +1192,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
11931192

11941193
ur_platform_handle_t Platform = nullptr;
11951194

1196-
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr);
1195+
Result = urPlatformGet(AdapterHandle, NumPlatforms, &Platform, nullptr);
11971196
if (Result != UR_RESULT_SUCCESS)
11981197
return Result;
11991198

unified-runtime/source/adapters/hip/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
5050
/// Triggers the HIP Driver initialization (hipInit) the first time, so this
5151
/// must be the first UR API called.
5252
UR_APIEXPORT ur_result_t UR_APICALL
53-
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
53+
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
5454
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
5555

5656
try {

unified-runtime/source/adapters/hip/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
200200
// hip backend has only one platform containing all devices
201201
ur_platform_handle_t platform;
202202
ur_adapter_handle_t AdapterHandle = &adapter;
203-
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr));
203+
UR_CHECK_ERROR(urPlatformGet(AdapterHandle, 1, &platform, nullptr));
204204

205205
// get the device from the platform
206206
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();

unified-runtime/source/adapters/level_zero/platform.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
namespace ur::level_zero {
1616

1717
ur_result_t urPlatformGet(
18-
ur_adapter_handle_t *, uint32_t,
18+
ur_adapter_handle_t,
1919
/// [in] the number of platforms to be added to phPlatforms. If phPlatforms
2020
/// is not NULL, then NumEntries should be greater than zero, otherwise
2121
/// ::UR_RESULT_ERROR_INVALID_SIZE, will be returned.
@@ -141,12 +141,12 @@ ur_result_t urPlatformCreateWithNativeHandle(
141141

142142
uint32_t NumPlatforms = 0;
143143
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
144-
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, 0, nullptr,
145-
&NumPlatforms));
144+
UR_CALL(
145+
ur::level_zero::urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms));
146146

147147
if (NumPlatforms) {
148148
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
149-
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumPlatforms,
149+
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumPlatforms,
150150
Platforms.data(), nullptr));
151151

152152
// The SYCL spec requires that the set of platforms must remain fixed for

unified-runtime/source/adapters/level_zero/queue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,8 @@ ur_result_t urQueueCreateWithNativeHandle(
786786
uint32_t NumEntries = 1;
787787
ur_platform_handle_t Platform{};
788788
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
789-
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumEntries,
790-
&Platform, nullptr));
789+
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumEntries, &Platform,
790+
nullptr));
791791

792792
ur_device_handle_t UrDevice = Device;
793793
if (UrDevice == nullptr) {

unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)