Skip to content

Commit 071f940

Browse files
authored
[UR][CUDA][HIP] Fixup UMF error mapping (#18011)
The UMF provided CUDA adapter uses the name "CUDA" and stores native errors as `CUresult`. Changes the HIP providers to use the name "HIP". The HIP providers use `ur_result_t` as native errors.
1 parent f899fb3 commit 071f940

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
#include <cuda.h>
2424

2525
namespace umf {
26-
ur_result_t getProviderNativeError(const char *, int32_t) {
27-
// TODO: implement when UMF supports CUDA
26+
ur_result_t getProviderNativeError(const char *providerName, int32_t error) {
27+
if (strcmp(providerName, "CUDA") == 0) {
28+
return mapErrorUR(static_cast<CUresult>(error));
29+
}
30+
2831
return UR_RESULT_ERROR_UNKNOWN;
2932
}
3033
} // namespace umf

unified-runtime/source/adapters/hip/usm.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@
1919
#include "usm.hpp"
2020

2121
namespace umf {
22-
ur_result_t getProviderNativeError(const char *, int32_t) {
23-
// TODO: implement when UMF supports HIP
22+
ur_result_t getProviderNativeError(const char *providerName,
23+
int32_t nativeError) {
24+
if (strcmp(providerName, "HIP") == 0) {
25+
// HIP provider stores native errors of ur_result_t type
26+
return static_cast<ur_result_t>(nativeError);
27+
}
28+
2429
return UR_RESULT_ERROR_UNKNOWN;
2530
}
2631
} // namespace umf

unified-runtime/source/adapters/hip/usm.hpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,36 +78,27 @@ class USMMemoryProvider {
7878
umf_result_t allocation_split(void *, size_t, size_t) {
7979
return UMF_RESULT_ERROR_UNKNOWN;
8080
}
81-
virtual const char *get_name() = 0;
81+
const char *get_name() { return "HIP"; }
8282

8383
virtual ~USMMemoryProvider() = default;
8484
};
8585

8686
// Allocation routines for shared memory type
8787
class USMSharedMemoryProvider final : public USMMemoryProvider {
88-
public:
89-
const char *get_name() override { return "USMSharedMemoryProvider"; }
90-
9188
protected:
9289
ur_result_t allocateImpl(void **ResultPtr, size_t Size,
9390
uint32_t Alignment) override;
9491
};
9592

9693
// Allocation routines for device memory type
9794
class USMDeviceMemoryProvider final : public USMMemoryProvider {
98-
public:
99-
const char *get_name() override { return "USMSharedMemoryProvider"; }
100-
10195
protected:
10296
ur_result_t allocateImpl(void **ResultPtr, size_t Size,
10397
uint32_t Alignment) override;
10498
};
10599

106100
// Allocation routines for host memory type
107101
class USMHostMemoryProvider final : public USMMemoryProvider {
108-
public:
109-
const char *get_name() override { return "USMSharedMemoryProvider"; }
110-
111102
protected:
112103
ur_result_t allocateImpl(void **ResultPtr, size_t Size,
113104
uint32_t Alignment) override;

0 commit comments

Comments
 (0)