Skip to content

Commit 66784dc

Browse files
authored
[OpenMP] Ensure Devices is accessed exclusively (llvm#74374)
We accessed the `Devices` container while holding the RTLsMtx most of the time, but not always. Sometimes we used the mutex for the size query, but then accessed Devices again unguarded. From now on, we properly encapsulate the container in a ProtectedObj, which ensures exclusive accesses. We also hide the "isReady" part in the `getDevice` accessor and use an `llvm::Expected` so that errors can be returned.
1 parent d6f4d52 commit 66784dc

File tree

9 files changed

+237
-226
lines changed

9 files changed

+237
-226
lines changed

openmp/libomptarget/include/PluginManager.h

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define OMPTARGET_PLUGIN_MANAGER_H
1515

1616
#include "DeviceImage.h"
17+
#include "ExclusiveAccess.h"
1718
#include "Shared/APITypes.h"
1819
#include "Shared/PluginAPI.h"
1920
#include "Shared/Requirements.h"
@@ -25,6 +26,7 @@
2526
#include "llvm/ADT/iterator.h"
2627
#include "llvm/ADT/iterator_range.h"
2728
#include "llvm/Support/DynamicLibrary.h"
29+
#include "llvm/Support/Error.h"
2830

2931
#include <cstdint>
3032
#include <list>
@@ -75,6 +77,13 @@ struct PluginAdaptorTy {
7577

7678
/// Struct for the data required to handle plugins
7779
struct PluginManager {
80+
/// Type of the devices container. We hand out DeviceTy& to queries which are
81+
/// stable addresses regardless if the container changes.
82+
using DeviceContainerTy = llvm::SmallVector<std::unique_ptr<DeviceTy>>;
83+
84+
/// Exclusive accessor type for the device container.
85+
using ExclusiveDevicesAccessorTy = Accessor<DeviceContainerTy>;
86+
7887
PluginManager() {}
7988

8089
void init();
@@ -89,13 +98,19 @@ struct PluginManager {
8998
DeviceImages.emplace_back(std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
9099
}
91100

101+
/// Return the device presented to the user as device \p DeviceNo if it is
102+
/// initialized and ready. Otherwise return an error explaining the problem.
103+
llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
104+
105+
/// Iterate over all initialized and ready devices registered with this
106+
/// plugin.
107+
auto devices(ExclusiveDevicesAccessorTy &DevicesAccessor) {
108+
return llvm::make_pointee_range(*DevicesAccessor);
109+
}
110+
92111
/// Iterate over all device images registered with this plugin.
93112
auto deviceImages() { return llvm::make_pointee_range(DeviceImages); }
94113

95-
/// Devices associated with RTLs
96-
llvm::SmallVector<std::unique_ptr<DeviceTy>> Devices;
97-
std::mutex RTLsMtx; ///< For RTLs and Devices
98-
99114
/// Translation table retrieved from the binary
100115
HostEntriesBeginToTransTableTy HostEntriesBeginToTransTable;
101116
std::mutex TrlTblMtx; ///< For Translation Table
@@ -124,9 +139,12 @@ struct PluginManager {
124139
DelayedBinDesc.clear();
125140
}
126141

127-
int getNumDevices() {
128-
std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
129-
return Devices.size();
142+
/// Return the number of usable devices.
143+
int getNumDevices() { return getExclusiveDevicesAccessor()->size(); }
144+
145+
/// Return an exclusive handle to access the devices container.
146+
ExclusiveDevicesAccessorTy getExclusiveDevicesAccessor() {
147+
return Devices.getExclusiveAccessor();
130148
}
131149

132150
int getNumUsedPlugins() const {
@@ -166,6 +184,11 @@ struct PluginManager {
166184

167185
/// The user provided requirements.
168186
RequirementCollection Requirements;
187+
188+
std::mutex RTLsMtx; ///< For RTLs
189+
190+
/// Devices associated with plugins, accesses to the container are exclusive.
191+
ProtectedObj<DeviceContainerTy> Devices;
169192
};
170193

171194
extern PluginManager *PM;

openmp/libomptarget/include/Shared/Debug.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,16 @@ inline uint32_t getDebugLevel() {
115115
/// Print a fatal error message consisting of an error identifier and an error
/// string to stderr, then abort the process.
///
/// \p _num may be an arbitrary integer expression (e.g. `Offset + I`); it is
/// converted to int as a whole so that it always matches the `%d` conversion.
/// \p _str must be a NUL-terminated C string.
#define FATAL_MESSAGE0(_num, _str)                                             \
  do {                                                                         \
    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n",              \
            static_cast<int>(_num), _str);                                     \
    abort();                                                                   \
  } while (0)
121122

122123
/// Print a fatal error message consisting of an error identifier and a
/// printf-style formatted string to stderr, then abort the process.
///
/// \p _num may be an arbitrary integer expression (e.g. `Offset + I`); it is
/// converted to int as a whole so that it always matches the `%d` conversion.
/// \p _str must be a string literal because it is concatenated into the format
/// string; the variadic arguments supply the values for its conversions.
#define FATAL_MESSAGE(_num, _str, ...)                                         \
  do {                                                                         \
    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n",        \
            static_cast<int>(_num), __VA_ARGS__);                              \
    abort();                                                                   \
  } while (0)
129130

openmp/libomptarget/include/device.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,8 @@ struct DeviceTy {
202202
/// completed and AsyncInfo.isDone() returns true.
203203
int32_t queryAsync(AsyncInfoTy &AsyncInfo);
204204

205-
/// Calls the corresponding print in the \p RTLDEVID
206-
/// device RTL to obtain the information of the specific device.
207-
bool printDeviceInfo(int32_t RTLDevID);
205+
/// Calls the corresponding print device info function in the plugin.
206+
bool printDeviceInfo();
208207

209208
/// Event related interfaces.
210209
/// {
@@ -245,6 +244,4 @@ struct DeviceTy {
245244
llvm::DenseMap<llvm::StringRef, OffloadEntryTy *> DeviceOffloadEntries;
246245
};
247246

248-
extern bool deviceIsReady(int DeviceNum);
249-
250247
#endif

openmp/libomptarget/src/OpenMP/InteropAPI.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "PluginManager.h"
1414
#include "device.h"
1515
#include "omptarget.h"
16+
#include "llvm/Support/Error.h"
17+
#include <cstdlib>
18+
#include <cstring>
1619

1720
extern "C" {
1821

@@ -190,6 +193,14 @@ __OMP_GET_INTEROP_TY3(const char *, type_desc)
190193
__OMP_GET_INTEROP_TY3(const char *, rc_desc)
191194
#undef __OMP_GET_INTEROP_TY3
192195

196+
/// Consume \p Err and return its message as a NUL-terminated C string.
///
/// The returned buffer is allocated with malloc and is not released by this
/// function. If the allocation fails, a pointer to a static fallback message
/// is returned instead, so the result is never null.
static const char *copyErrorString(llvm::Error &&Err) {
  // TODO: Use the error string while avoiding leaks.
  // Converting to a string also consumes the error, which has to happen on
  // every path, including the allocation-failure path below.
  std::string ErrMsg = llvm::toString(std::move(Err));

  const size_t NumBytes = ErrMsg.size() + 1; // Include the terminating NUL.
  char *UsrMsg = static_cast<char *>(malloc(NumBytes));
  if (!UsrMsg)
    return "Failed to allocate memory for the error message";

  memcpy(UsrMsg, ErrMsg.c_str(), NumBytes);
  return UsrMsg;
}
203+
193204
extern "C" {
194205

195206
void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
@@ -211,12 +222,14 @@ void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
211222
}
212223

213224
InteropPtr = new omp_interop_val_t(DeviceId, InteropType);
214-
if (!deviceIsReady(DeviceId)) {
215-
InteropPtr->err_str = "Device not ready!";
225+
226+
auto DeviceOrErr = PM->getDevice(DeviceId);
227+
if (!DeviceOrErr) {
228+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
216229
return;
217230
}
218231

219-
DeviceTy &Device = *PM->Devices[DeviceId];
232+
DeviceTy &Device = *DeviceOrErr;
220233
if (!Device.RTL || !Device.RTL->init_device_info ||
221234
Device.RTL->init_device_info(DeviceId, &(InteropPtr)->device_info,
222235
&(InteropPtr)->err_str)) {
@@ -248,8 +261,9 @@ void __tgt_interop_use(ident_t *LocRef, int32_t Gtid,
248261
assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
249262
"Inconsistent device-id usage!");
250263

251-
if (!deviceIsReady(DeviceId)) {
252-
InteropPtr->err_str = "Device not ready!";
264+
auto DeviceOrErr = PM->getDevice(DeviceId);
265+
if (!DeviceOrErr) {
266+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
253267
return;
254268
}
255269

@@ -277,8 +291,9 @@ void __tgt_interop_destroy(ident_t *LocRef, int32_t Gtid,
277291

278292
assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
279293
"Inconsistent device-id usage!");
280-
if (!deviceIsReady(DeviceId)) {
281-
InteropPtr->err_str = "Device not ready!";
294+
auto DeviceOrErr = PM->getDevice(DeviceId);
295+
if (!DeviceOrErr) {
296+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
282297
return;
283298
}
284299

openmp/libomptarget/src/PluginManager.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "PluginManager.h"
14+
#include "Shared/Debug.h"
15+
16+
#include "llvm/Support/Error.h"
17+
#include "llvm/Support/ErrorHandling.h"
1418

1519
using namespace llvm;
1620
using namespace llvm::sys;
@@ -71,7 +75,12 @@ PluginAdaptorTy::PluginAdaptorTy(const std::string &Name) : Name(Name) {
7175

7276
void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
7377
for (int32_t I = 0; I < NumberOfDevices; ++I) {
74-
DeviceTy &Device = *PM->Devices[DeviceOffset + I];
78+
auto DeviceOrErr = PM->getDevice(DeviceOffset + I);
79+
if (!DeviceOrErr)
80+
FATAL_MESSAGE(DeviceOffset + I, "%s",
81+
toString(DeviceOrErr.takeError()).c_str());
82+
83+
DeviceTy &Device = *DeviceOrErr;
7584
for (OffloadEntryTy &Entry : DI.entries())
7685
Device.addOffloadEntry(Entry);
7786
}
@@ -97,14 +106,15 @@ void PluginManager::initPlugin(PluginAdaptorTy &Plugin) {
97106
return;
98107

99108
// Initialize the device information for the RTL we are about to use.
100-
const size_t Start = Devices.size();
101-
Devices.reserve(Start + Plugin.NumberOfDevices);
109+
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
110+
const size_t Start = ExclusiveDevicesAccessor->size();
111+
ExclusiveDevicesAccessor->reserve(Start + Plugin.NumberOfDevices);
102112
for (int32_t DeviceId = 0; DeviceId < Plugin.NumberOfDevices; DeviceId++) {
103-
Devices.push_back(std::make_unique<DeviceTy>(&Plugin));
113+
ExclusiveDevicesAccessor->push_back(std::make_unique<DeviceTy>(&Plugin));
104114
// global device ID
105-
Devices[Start + DeviceId]->DeviceID = Start + DeviceId;
115+
(*ExclusiveDevicesAccessor)[Start + DeviceId]->DeviceID = Start + DeviceId;
106116
// RTL local device ID
107-
Devices[Start + DeviceId]->RTLDeviceID = DeviceId;
117+
(*ExclusiveDevicesAccessor)[Start + DeviceId]->RTLDeviceID = DeviceId;
108118
}
109119

110120
// Initialize the index of this RTL and save it in the used RTLs.
@@ -254,7 +264,12 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
254264
// Execute dtors for static objects if the device has been used, i.e.
255265
// if its PendingCtors list has been emptied.
256266
for (int32_t I = 0; I < FoundRTL->NumberOfDevices; ++I) {
257-
DeviceTy &Device = *PM->Devices[FoundRTL->DeviceOffset + I];
267+
auto DeviceOrErr = PM->getDevice(FoundRTL->DeviceOffset + I);
268+
if (!DeviceOrErr)
269+
FATAL_MESSAGE(FoundRTL->DeviceOffset + I, "%s",
270+
toString(DeviceOrErr.takeError()).c_str());
271+
272+
DeviceTy &Device = *DeviceOrErr;
258273
Device.PendingGlobalsMtx.lock();
259274
if (Device.PendingCtorsDtors[Desc].PendingCtors.empty()) {
260275
AsyncInfoTy AsyncInfo(Device);
@@ -313,3 +328,26 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
313328

314329
DP("Done unregistering library!\n");
315330
}
331+
332+
/// Return the device registered under the user-visible number \p DeviceNo,
/// initializing it first if that has not happened yet.
///
/// \returns a reference to the device on success. Otherwise an error is
/// returned, stating either that \p DeviceNo is not smaller than the number of
/// registered devices or that the device could not be initialized.
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
  // NOTE(review): this accessor stays alive while initOnce() runs below;
  // confirm that initOnce() never needs access to the devices container.
  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();

  const size_t NumDevices = ExclusiveDevicesAccessor->size();
  if (DeviceNo >= NumDevices)
    // The conversions must match the argument types exactly since the
    // message is produced by printf-style formatting.
    return createStringError(
        inconvertibleErrorCode(),
        "Device number '%u' out of range, only %zu devices available",
        DeviceNo, NumDevices);

  DeviceTy &Device = *(*ExclusiveDevicesAccessor)[DeviceNo];

  DP("Is the device %d (local ID %d) initialized? %d\n", DeviceNo,
     Device.RTLDeviceID, Device.IsInit);

  // Init the device if not done before.
  if (!Device.IsInit && Device.initOnce() != OFFLOAD_SUCCESS)
    return createStringError(inconvertibleErrorCode(),
                             "Failed to init device %u", DeviceNo);

  DP("Device %d is ready to use.\n", DeviceNo);
  return Device;
}

0 commit comments

Comments
 (0)