Skip to content

Commit 4321b97

Browse files
dhruvachakronlieb
authored and committed
devices cherry-pick patch 66784dc
Restores: [OpenMP] Ensure `Devices` is accessed exclusively (llvm#74374) Change-Id: I34e5814a76c61cba9deae2c129e3aae96116662e
1 parent 828e8dc commit 4321b97

File tree

10 files changed

+254
-246
lines changed

10 files changed

+254
-246
lines changed

openmp/libomptarget/include/PluginManager.h

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define OMPTARGET_PLUGIN_MANAGER_H
1515

1616
#include "DeviceImage.h"
17+
#include "ExclusiveAccess.h"
1718
#include "Shared/APITypes.h"
1819
#include "Shared/PluginAPI.h"
1920
#include "Shared/Requirements.h"
@@ -25,6 +26,7 @@
2526
#include "llvm/ADT/iterator.h"
2627
#include "llvm/ADT/iterator_range.h"
2728
#include "llvm/Support/DynamicLibrary.h"
29+
#include "llvm/Support/Error.h"
2830

2931
#include <cstdint>
3032
#include <list>
@@ -75,6 +77,13 @@ struct PluginAdaptorTy {
7577

7678
/// Struct for the data required to handle plugins
7779
struct PluginManager {
80+
/// Type of the devices container. We hand out DeviceTy& to queries; these
81+
/// references stay valid regardless of whether the container changes.
82+
using DeviceContainerTy = llvm::SmallVector<std::unique_ptr<DeviceTy>>;
83+
84+
/// Exclusive accessor type for the device container.
85+
using ExclusiveDevicesAccessorTy = Accessor<DeviceContainerTy>;
86+
7887
PluginManager() {}
7988

8089
void init();
@@ -89,13 +98,19 @@ struct PluginManager {
8998
DeviceImages.emplace_back(std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
9099
}
91100

101+
/// Return the device presented to the user as device \p DeviceNo if it is
102+
/// initialized and ready. Otherwise return an error explaining the problem.
103+
llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
104+
105+
/// Iterate over all initialized and ready devices registered with this
106+
/// plugin.
107+
auto devices(ExclusiveDevicesAccessorTy &DevicesAccessor) {
108+
return llvm::make_pointee_range(*DevicesAccessor);
109+
}
110+
92111
/// Iterate over all device images registered with this plugin.
93112
auto deviceImages() { return llvm::make_pointee_range(DeviceImages); }
94113

95-
/// Devices associated with RTLs
96-
std::vector<std::unique_ptr<DeviceTy>> Devices;
97-
std::mutex RTLsMtx; ///< For RTLs and Devices
98-
99114
/// Translation table retrieved from the binary
100115
HostEntriesBeginToTransTableTy HostEntriesBeginToTransTable;
101116
std::mutex TrlTblMtx; ///< For Translation Table
@@ -124,9 +139,12 @@ struct PluginManager {
124139
DelayedBinDesc.clear();
125140
}
126141

127-
int getNumDevices() {
128-
std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
129-
return Devices.size();
142+
/// Return the number of usable devices.
143+
int getNumDevices() { return getExclusiveDevicesAccessor()->size(); }
144+
145+
/// Return an exclusive handle to access the devices container.
146+
ExclusiveDevicesAccessorTy getExclusiveDevicesAccessor() {
147+
return Devices.getExclusiveAccessor();
130148
}
131149

132150
int getNumUsedPlugins() const {
@@ -166,6 +184,11 @@ struct PluginManager {
166184

167185
/// The user provided requirements.
168186
RequirementCollection Requirements;
187+
188+
std::mutex RTLsMtx; ///< For RTLs
189+
190+
/// Devices associated with plugins, accesses to the container are exclusive.
191+
ProtectedObj<DeviceContainerTy> Devices;
169192
};
170193

171194
extern PluginManager *PM;

openmp/libomptarget/include/Shared/Debug.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,16 @@ inline uint32_t getDebugLevel() {
141141
/// Print fatal error message with an error string and error identifier
142142
#define FATAL_MESSAGE0(_num, _str) \
143143
do { \
144-
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", _num, _str); \
144+
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", (int)_num, \
145+
_str); \
145146
abort(); \
146147
} while (0)
147148

148149
/// Print fatal error message with a printf string and error identifier
149150
#define FATAL_MESSAGE(_num, _str, ...) \
150151
do { \
151-
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", _num, \
152-
__VA_ARGS__); \
152+
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", \
153+
(int)_num, __VA_ARGS__); \
153154
abort(); \
154155
} while (0)
155156

openmp/libomptarget/include/device.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,8 @@ struct DeviceTy {
212212
/// completed and AsyncInfo.isDone() returns true.
213213
int32_t queryAsync(AsyncInfoTy &AsyncInfo);
214214

215-
/// Calls the corresponding print in the \p RTLDEVID
216-
/// device RTL to obtain the information of the specific device.
217-
bool printDeviceInfo(int32_t RTLDevID);
215+
/// Calls the corresponding print device info function in the plugin.
216+
bool printDeviceInfo();
218217

219218
/// Event related interfaces.
220219
/// {
@@ -258,6 +257,4 @@ struct DeviceTy {
258257
llvm::DenseMap<llvm::StringRef, OffloadEntryTy *> DeviceOffloadEntries;
259258
};
260259

261-
extern bool deviceIsReady(int DeviceNum);
262-
263260
#endif

openmp/libomptarget/src/OpenMP/InteropAPI.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "PluginManager.h"
1414
#include "device.h"
1515
#include "omptarget.h"
16+
#include "llvm/Support/Error.h"
17+
#include <cstdlib>
18+
#include <cstring>
1619

1720
extern "C" {
1821

@@ -190,6 +193,14 @@ __OMP_GET_INTEROP_TY3(const char *, type_desc)
190193
__OMP_GET_INTEROP_TY3(const char *, rc_desc)
191194
#undef __OMP_GET_INTEROP_TY3
192195

196+
/// Consume \p Err and return its message as a NUL-terminated C string.
///
/// The message is copied into a buffer obtained from malloc; this function
/// never frees that buffer, so ownership passes to the caller (see the TODO
/// below regarding the resulting leak). If the allocation fails, a pointer to
/// a string literal is returned instead, so the result is never null.
static const char *copyErrorString(llvm::Error &&Err) {
  // TODO: Use the error string while avoiding leaks.
  std::string ErrMsg = llvm::toString(std::move(Err));

  // Reserve room for the message plus its terminating NUL.
  const size_t NumBytes = ErrMsg.size() + 1;
  char *UsrMsg = static_cast<char *>(std::malloc(NumBytes));

  // Do not write through a null pointer if we ran out of memory; fall back to
  // a static message so the caller still gets a printable string.
  if (!UsrMsg)
    return "Failed to allocate memory for the error message";

  std::memcpy(UsrMsg, ErrMsg.c_str(), NumBytes);
  return UsrMsg;
}
203+
193204
extern "C" {
194205

195206
void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
@@ -211,12 +222,14 @@ void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
211222
}
212223

213224
InteropPtr = new omp_interop_val_t(DeviceId, InteropType);
214-
if (!deviceIsReady(DeviceId)) {
215-
InteropPtr->err_str = "Device not ready!";
225+
226+
auto DeviceOrErr = PM->getDevice(DeviceId);
227+
if (!DeviceOrErr) {
228+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
216229
return;
217230
}
218231

219-
DeviceTy &Device = *PM->Devices[DeviceId];
232+
DeviceTy &Device = *DeviceOrErr;
220233
if (!Device.RTL || !Device.RTL->init_device_info ||
221234
Device.RTL->init_device_info(DeviceId, &(InteropPtr)->device_info,
222235
&(InteropPtr)->err_str)) {
@@ -248,8 +261,9 @@ void __tgt_interop_use(ident_t *LocRef, int32_t Gtid,
248261
assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
249262
"Inconsistent device-id usage!");
250263

251-
if (!deviceIsReady(DeviceId)) {
252-
InteropPtr->err_str = "Device not ready!";
264+
auto DeviceOrErr = PM->getDevice(DeviceId);
265+
if (!DeviceOrErr) {
266+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
253267
return;
254268
}
255269

@@ -277,8 +291,9 @@ void __tgt_interop_destroy(ident_t *LocRef, int32_t Gtid,
277291

278292
assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
279293
"Inconsistent device-id usage!");
280-
if (!deviceIsReady(DeviceId)) {
281-
InteropPtr->err_str = "Device not ready!";
294+
auto DeviceOrErr = PM->getDevice(DeviceId);
295+
if (!DeviceOrErr) {
296+
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
282297
return;
283298
}
284299

openmp/libomptarget/src/PluginManager.cpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "PluginManager.h"
1314
#include "OmptTracing.h"
1415
#include "OpenMP/OMPT/Callback.h"
15-
#include "PluginManager.h"
16+
#include "Shared/Debug.h"
17+
18+
#include "llvm/Support/Error.h"
19+
#include "llvm/Support/ErrorHandling.h"
1620

1721
using namespace llvm;
1822
using namespace llvm::sys;
@@ -73,7 +77,12 @@ PluginAdaptorTy::PluginAdaptorTy(const std::string &Name) : Name(Name) {
7377

7478
void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
7579
for (int32_t I = 0; I < NumberOfDevices; ++I) {
76-
DeviceTy &Device = *PM->Devices[DeviceOffset + I];
80+
auto DeviceOrErr = PM->getDevice(DeviceOffset + I);
81+
if (!DeviceOrErr)
82+
FATAL_MESSAGE(DeviceOffset + I, "%s",
83+
toString(DeviceOrErr.takeError()).c_str());
84+
85+
DeviceTy &Device = *DeviceOrErr;
7786
for (OffloadEntryTy &Entry : DI.entries())
7887
Device.addOffloadEntry(Entry);
7988
}
@@ -99,14 +108,15 @@ void PluginManager::initPlugin(PluginAdaptorTy &Plugin) {
99108
return;
100109

101110
// Initialize the device information for the RTL we are about to use.
102-
const size_t Start = Devices.size();
103-
Devices.reserve(Start + Plugin.NumberOfDevices);
111+
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
112+
const size_t Start = ExclusiveDevicesAccessor->size();
113+
ExclusiveDevicesAccessor->reserve(Start + Plugin.NumberOfDevices);
104114
for (int32_t DeviceId = 0; DeviceId < Plugin.NumberOfDevices; DeviceId++) {
105-
Devices.push_back(std::make_unique<DeviceTy>(&Plugin));
115+
ExclusiveDevicesAccessor->push_back(std::make_unique<DeviceTy>(&Plugin));
106116
// global device ID
107-
Devices[Start + DeviceId]->DeviceID = Start + DeviceId;
117+
(*ExclusiveDevicesAccessor)[Start + DeviceId]->DeviceID = Start + DeviceId;
108118
// RTL local device ID
109-
Devices[Start + DeviceId]->RTLDeviceID = DeviceId;
119+
(*ExclusiveDevicesAccessor)[Start + DeviceId]->RTLDeviceID = DeviceId;
110120
}
111121

112122
// Initialize the index of this RTL and save it in the used RTLs.
@@ -270,7 +280,12 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
270280
// Execute dtors for static objects if the device has been used, i.e.
271281
// if its PendingCtors list has been emptied.
272282
for (int32_t I = 0; I < FoundRTL->NumberOfDevices; ++I) {
273-
DeviceTy &Device = *PM->Devices[FoundRTL->DeviceOffset + I];
283+
auto DeviceOrErr = PM->getDevice(FoundRTL->DeviceOffset + I);
284+
if (!DeviceOrErr)
285+
FATAL_MESSAGE(FoundRTL->DeviceOffset + I, "%s",
286+
toString(DeviceOrErr.takeError()).c_str());
287+
288+
DeviceTy &Device = *DeviceOrErr;
274289
Device.PendingGlobalsMtx.lock();
275290
if (Device.PendingCtorsDtors[Desc].PendingCtors.empty()) {
276291
AsyncInfoTy AsyncInfo(Device);
@@ -329,3 +344,26 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
329344

330345
DP("Done unregistering library!\n");
331346
}
347+
348+
/// Return the device the user refers to as \p DeviceNo, initializing it first
/// if that has not happened yet.
///
/// \returns a reference to the device on success. An error is returned if
/// \p DeviceNo is not smaller than the number of registered devices, or if
/// the device's initOnce() does not report OFFLOAD_SUCCESS.
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
  // The accessor stays alive until this function returns, i.e. also while the
  // device is being initialized below.
  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();

  // The format strings below use printf-style int conversions, so hand them
  // values of exactly that type. Printing DeviceNo as a signed value also
  // keeps a negative user-provided device id readable (e.g. -1).
  const int32_t SignedDeviceNo = static_cast<int32_t>(DeviceNo);

  if (DeviceNo >= ExclusiveDevicesAccessor->size())
    return createStringError(
        inconvertibleErrorCode(),
        "Device number '%i' out of range, only %i devices available",
        SignedDeviceNo, static_cast<int>(ExclusiveDevicesAccessor->size()));

  DeviceTy &Device = *(*ExclusiveDevicesAccessor)[DeviceNo];

  DP("Is the device %d (local ID %d) initialized? %d\n", DeviceNo,
     Device.RTLDeviceID, Device.IsInit);

  // Init the device if not done before
  if (!Device.IsInit && Device.initOnce() != OFFLOAD_SUCCESS) {
    // No trailing newline: users of the error text add their own line break.
    return createStringError(inconvertibleErrorCode(),
                             "Failed to init device %d", SignedDeviceNo);
  }

  DP("Device %d is ready to use.\n", DeviceNo);
  return Device;
}

0 commit comments

Comments
 (0)