17
17
#include " device.hpp"
18
18
#include " platform.hpp"
19
19
#include " ur/ur.hpp"
20
+ #include " ur2offload.hpp"
20
21
#include " ur_api.h"
21
22
22
- ur_adapter_handle_t_ Adapter{} ;
23
+ ur_adapter_handle_t Adapter = nullptr ;
23
24
24
25
// Initialize liboffload and perform the initial platform and device discovery
25
26
ur_result_t ur_adapter_handle_t_::init () {
@@ -30,7 +31,7 @@ ur_result_t ur_adapter_handle_t_::init() {
30
31
Res = olIterateDevices (
31
32
[](ol_device_handle_t D, void *UserData) {
32
33
auto *Platforms =
33
- reinterpret_cast <decltype (Adapter. Platforms ) *>(UserData);
34
+ reinterpret_cast <decltype (Adapter-> Platforms ) *>(UserData);
34
35
35
36
ol_platform_handle_t Platform;
36
37
olGetDeviceInfo (D, OL_DEVICE_INFO_PLATFORM, sizeof (Platform),
@@ -39,7 +40,7 @@ ur_result_t ur_adapter_handle_t_::init() {
39
40
olGetPlatformInfo (Platform, OL_PLATFORM_INFO_BACKEND, sizeof (Backend),
40
41
&Backend);
41
42
if (Backend == OL_PLATFORM_BACKEND_HOST) {
42
- Adapter. HostDevice = D;
43
+ Adapter-> HostDevice = D;
43
44
} else if (Backend != OL_PLATFORM_BACKEND_UNKNOWN) {
44
45
auto URPlatform =
45
46
std::find_if (Platforms->begin (), Platforms->end (), [&](auto &P) {
@@ -57,37 +58,52 @@ ur_result_t ur_adapter_handle_t_::init() {
57
58
}
58
59
return false ;
59
60
},
60
- &Adapter. Platforms );
61
+ &Adapter-> Platforms );
61
62
62
- (void )Res;
63
-
64
- return UR_RESULT_SUCCESS;
63
+ return offloadResultToUR (Res);
65
64
}
66
65
67
66
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet (
68
67
uint32_t , ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
68
+ std::mutex InitMutex{};
69
+
69
70
if (phAdapters) {
70
- if (++Adapter.RefCount == 1 ) {
71
- Adapter.init ();
71
+ std::lock_guard Guard{InitMutex};
72
+
73
+ // We explicitly only initialize the adapter when outputting it
74
+ if (!Adapter) {
75
+ Adapter = new ur_adapter_handle_t_{};
76
+ auto Res = Adapter->init ();
77
+ if (Res) {
78
+ delete Adapter;
79
+ Adapter = nullptr ;
80
+ return Res;
81
+ }
72
82
}
73
- *phAdapters = &Adapter;
83
+ Adapter->RefCount ++;
84
+ *phAdapters = Adapter;
74
85
}
86
+
75
87
if (pNumAdapters) {
76
88
*pNumAdapters = 1 ;
77
89
}
90
+
78
91
return UR_RESULT_SUCCESS;
79
92
}
80
93
81
94
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
82
- if (--Adapter.RefCount == 0 ) {
95
+ // Doesn't need protecting by a lock - There is no way to reinitialize the
96
+ // adapter after the final reference is released
97
+ if (--Adapter->RefCount == 0 ) {
83
98
// This can crash when tracing is enabled.
84
99
// olShutDown();
100
+ delete Adapter;
85
101
};
86
102
return UR_RESULT_SUCCESS;
87
103
}
88
104
89
105
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
90
- Adapter. RefCount ++;
106
+ Adapter-> RefCount ++;
91
107
return UR_RESULT_SUCCESS;
92
108
}
93
109
@@ -102,7 +118,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
102
118
case UR_ADAPTER_INFO_BACKEND:
103
119
return ReturnValue (UR_BACKEND_OFFLOAD);
104
120
case UR_ADAPTER_INFO_REFERENCE_COUNT:
105
- return ReturnValue (Adapter. RefCount .load ());
121
+ return ReturnValue (Adapter-> RefCount .load ());
106
122
case UR_ADAPTER_INFO_VERSION:
107
123
return ReturnValue (1 );
108
124
default :
@@ -124,15 +140,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterSetLoggerCallback(
124
140
ur_adapter_handle_t , ur_logger_callback_t pfnLoggerCallback,
125
141
void *pUserData, ur_logger_level_t level = UR_LOGGER_LEVEL_QUIET) {
126
142
127
- Adapter. Logger .setCallbackSink (pfnLoggerCallback, pUserData, level);
143
+ Adapter-> Logger .setCallbackSink (pfnLoggerCallback, pUserData, level);
128
144
129
145
return UR_RESULT_SUCCESS;
130
146
}
131
147
132
148
UR_APIEXPORT ur_result_t UR_APICALL
133
149
urAdapterSetLoggerCallbackLevel (ur_adapter_handle_t , ur_logger_level_t level) {
134
150
135
- Adapter. Logger .setCallbackLevel (level);
151
+ Adapter-> Logger .setCallbackLevel (level);
136
152
137
153
return UR_RESULT_SUCCESS;
138
154
}
0 commit comments