Skip to content

[UR] urPlatformGet() takes only 1 adapter. #17876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sycl/source/detail/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ class Adapter {
std::vector<ur_platform_handle_t> &getUrPlatforms() {
std::call_once(PlatformsPopulated, [&]() {
uint32_t platformCount = 0;
call<UrApiKind::urPlatformGet>(&MAdapter, 1, 0, nullptr, &platformCount);
call<UrApiKind::urPlatformGet>(MAdapter, 0, nullptr, &platformCount);
UrPlatforms.resize(platformCount);
if (platformCount) {
call<UrApiKind::urPlatformGet>(&MAdapter, 1, platformCount,
call<UrApiKind::urPlatformGet>(MAdapter, platformCount,
UrPlatforms.data(), nullptr);
}
// We need one entry in this per platform
Expand Down
15 changes: 10 additions & 5 deletions unified-runtime/examples/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@ get_supported_adapters(std::vector<ur_adapter_handle_t> &adapters) {
std::vector<ur_platform_handle_t>
get_platforms(std::vector<ur_adapter_handle_t> &adapters) {
uint32_t platformCount = 0;
ur_check(urPlatformGet(adapters.data(), adapters.size(), 1, nullptr,
&platformCount));
std::vector<ur_platform_handle_t> platforms;
for (auto adapter : adapters) {
uint32_t adapterPlatformCount = 0;
urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);

platforms.reserve(platformCount + adapterPlatformCount);
urPlatformGet(adapter, adapterPlatformCount, &platforms[platformCount],
&adapterPlatformCount);
platformCount += adapterPlatformCount;
}
if (!platformCount) {
throw std::runtime_error("No platforms available.");
}
platforms.resize(platformCount);
Comment on lines +72 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is undefined behavior: the access to platforms[platformCount] is out of bounds because platforms.size() is always 0 here. This should use resize(), not reserve().
In practice, the later .resize() call resets every platform handle to nullptr, which makes the example crash. See #18032


std::vector<ur_platform_handle_t> platforms(platformCount);
ur_check(urPlatformGet(adapters.data(), adapters.size(), platformCount,
platforms.data(), nullptr));
return platforms;
}

Expand Down
32 changes: 18 additions & 14 deletions unified-runtime/examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,26 @@ int main(int, char *[]) {
return 1;
}

status =
urPlatformGet(adapters.data(), adapterCount, 1, nullptr, &platformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}
for (auto adapter : adapters) {
uint32_t adapterPlatformCount = 0;
status = urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

platforms.resize(platformCount);
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
platforms.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
platforms.reserve(platformCount + adapterPlatformCount);
status = urPlatformGet(adapter, adapterPlatformCount,
&platforms[platformCount], &adapterPlatformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}
platformCount += adapterPlatformCount;
}
platforms.resize(platformCount);

for (auto p : platforms) {
ur_api_version_t api_version = {};
Expand Down
13 changes: 5 additions & 8 deletions unified-runtime/include/ur_api.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions unified-runtime/include/ur_ddi.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 2 additions & 18 deletions unified-runtime/include/ur_print.hpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions unified-runtime/scripts/core/PROG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ Initialization and Discovery
${x}AdapterGet(adapterCount, adapters.data(), nullptr);

// Discover all the platform instances
uint32_t platformCount = 0;
${x}PlatformGet(adapters.data(), adapterCount, 0, nullptr, &platformCount);

std::vector<${x}_platform_handle_t> platforms(platformCount);
${x}PlatformGet(adapters.data(), adapterCount, platform.size(), platforms.data(), &platformCount);
std::vector<${x}_platform_handle_t> platforms;
uint32_t totalPlatformCount = 0;
for (auto adapter : adapters) {
    // First call: ask this adapter how many platforms it exposes.
    uint32_t adapterPlatformCount = 0;
    ${x}PlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
    if (adapterPlatformCount == 0) {
        // Nothing to fetch; a non-null output array with NumEntries == 0
        // is rejected as an invalid size.
        continue;
    }

    // Use resize(), not reserve(): reserve() only grows capacity and leaves
    // size() unchanged, so writing into the new slots would be out of bounds.
    platforms.resize(totalPlatformCount + adapterPlatformCount);

    // Second call: append this adapter's handles after those already collected.
    ${x}PlatformGet(adapter, adapterPlatformCount, platforms.data() + totalPlatformCount, nullptr);
    totalPlatformCount += adapterPlatformCount;
}

// Get number of total GPU devices in the platform
uint32_t deviceCount = 0;
Expand Down
9 changes: 3 additions & 6 deletions unified-runtime/scripts/core/platform.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ details:
- "Multiple calls to this function will return identical platform handles, in the same order."
- "The application may call this function from simultaneous threads, the implementation must be thread-safe"
params:
- type: "$x_adapter_handle_t*"
name: "phAdapters"
desc: "[in][range(0, NumAdapters)] array of adapters to query for platforms."
- type: "uint32_t"
name: "NumAdapters"
desc: "[in] number of adapters pointed to by phAdapters"
- type: "$x_adapter_handle_t"
name: "hAdapter"
desc: "[in] adapter to query for platforms."
- type: "uint32_t"
name: NumEntries
desc: |
Expand Down
56 changes: 23 additions & 33 deletions unified-runtime/scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -101,49 +101,39 @@ namespace ur_loader
}

%elif func_basename == "PlatformGet":
uint32_t total_platform_handle_count = 0;

for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
{
// extract adapter's function pointer table
auto dditable =
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']}[adapter_index])->dditable;
// extract adapter's function pointer table
auto dditable =
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']})->dditable;

if( ( 0 < ${obj['params'][2]['name']} ) && ( ${obj['params'][2]['name']} == total_platform_handle_count))
break;
uint32_t library_platform_handle_count = 0;

uint32_t library_platform_handle_count = 0;
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, 0, nullptr, &library_platform_handle_count );
if( ${X}_RESULT_SUCCESS != result ) return result;

result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, 0, nullptr, &library_platform_handle_count );
if( ${X}_RESULT_SUCCESS != result ) break;
if( nullptr != ${obj['params'][2]['name']} && ${obj['params'][1]['name']} !=0)
{
if( library_platform_handle_count > ${obj['params'][1]['name']}) {
library_platform_handle_count = ${obj['params'][1]['name']};
}
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, library_platform_handle_count, ${obj['params'][2]['name']}, nullptr );
if( ${X}_RESULT_SUCCESS != result ) return result;

if( nullptr != ${obj['params'][3]['name']} && ${obj['params'][2]['name']} !=0)
try
{
if( total_platform_handle_count + library_platform_handle_count > ${obj['params'][2]['name']}) {
library_platform_handle_count = ${obj['params'][2]['name']} - total_platform_handle_count;
}
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, library_platform_handle_count, &${obj['params'][3]['name']}[ total_platform_handle_count ], nullptr );
if( ${X}_RESULT_SUCCESS != result ) break;

try
{
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
uint32_t platform_index = total_platform_handle_count + i;
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
}
}
catch( std::bad_alloc& )
{
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
${obj['params'][2]['name']}[ i ] = reinterpret_cast<${n}_platform_handle_t>(
context->factories.${n}_platform_factory.getInstance( ${obj['params'][2]['name']}[ i ], dditable ) );
}
}

total_platform_handle_count += library_platform_handle_count;
catch( std::bad_alloc& )
{
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
}

if( ${X}_RESULT_SUCCESS == result && ${obj['params'][4]['name']} != nullptr )
*${obj['params'][4]['name']} = total_platform_handle_count;
if( ${X}_RESULT_SUCCESS == result && ${obj['params'][3]['name']} != nullptr )
*${obj['params'][3]['name']} = library_platform_handle_count;

%else:
<%param_replacements={}%>
Expand Down
5 changes: 2 additions & 3 deletions unified-runtime/source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1285,15 +1285,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
return Result;

std::vector<ur_platform_handle_t> Platforms(NumPlatforms);

Result =
urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), nullptr);
urPlatformGet(AdapterHandle, NumPlatforms, Platforms.data(), nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
/// Triggers the CUDA Driver initialization (cuInit) the first time, so this
/// must be the first PI API called.
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

try {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
// cuda backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr);

// get the device from the platform
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();
Expand Down
5 changes: 2 additions & 3 deletions unified-runtime/source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1182,8 +1182,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand All @@ -1193,7 +1192,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

ur_platform_handle_t Platform = nullptr;

Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr);
Result = urPlatformGet(AdapterHandle, NumPlatforms, &Platform, nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/hip/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
/// Triggers the HIP Driver initialization (hipInit) the first time, so this
/// must be the first UR API called.
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

try {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
// hip backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr));
UR_CHECK_ERROR(urPlatformGet(AdapterHandle, 1, &platform, nullptr));

// get the device from the platform
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace ur::level_zero {

ur_result_t urPlatformGet(
ur_adapter_handle_t *, uint32_t,
ur_adapter_handle_t,
/// [in] the number of platforms to be added to phPlatforms. If phPlatforms
/// is not NULL, then NumEntries should be greater than zero, otherwise
/// ::UR_RESULT_ERROR_INVALID_SIZE, will be returned.
Expand Down Expand Up @@ -141,12 +141,12 @@ ur_result_t urPlatformCreateWithNativeHandle(

uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, 0, nullptr,
&NumPlatforms));
UR_CALL(
ur::level_zero::urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms));

if (NumPlatforms) {
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumPlatforms,
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumPlatforms,
Platforms.data(), nullptr));

// The SYCL spec requires that the set of platforms must remain fixed for
Expand Down
4 changes: 2 additions & 2 deletions unified-runtime/source/adapters/level_zero/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,8 @@ ur_result_t urQueueCreateWithNativeHandle(
uint32_t NumEntries = 1;
ur_platform_handle_t Platform{};
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumEntries,
&Platform, nullptr));
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumEntries, &Platform,
nullptr));

ur_device_handle_t UrDevice = Device;
if (UrDevice == nullptr) {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading