Skip to content

Commit ff12c00

Browse files
jdoerfert and jhuber6 authored
[Offload] Ensure to load images when the device is used (#103002)
When we use the device, e.g., with an API that interacts with it, we need to ensure the image is loaded and the constructors are executed. Two tests are included to verify that we 1) load images and run constructors when needed, and 2) do so lazily, only if the device is actually used. --------- Co-authored-by: Joseph Huber <[email protected]>
1 parent 101acff commit ff12c00

File tree

7 files changed

+303
-229
lines changed

7 files changed

+303
-229
lines changed

offload/include/device.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ struct DeviceTy {
152152
/// Ask the device whether the runtime should use auto zero-copy.
153153
bool useAutoZeroCopy();
154154

155+
/// Check if there are pending images for this device.
156+
bool hasPendingImages() const { return HasPendingImages; }
157+
158+
/// Indicate whether there are pending images for this device.
159+
void setHasPendingImages(bool V) { HasPendingImages = V; }
160+
155161
private:
156162
/// Deinitialize the device (and plugin).
157163
void deinit();
@@ -163,6 +169,9 @@ struct DeviceTy {
163169

164170
/// Handler to collect and organize host-2-device mapping information.
165171
MappingInfoTy MappingInfo;
172+
173+
/// Flag to indicate pending images (true after construction).
174+
bool HasPendingImages = true;
166175
};
167176

168177
#endif

offload/src/PluginManager.cpp

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,13 @@ bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {
7878

7979
bool PluginManager::initializeDevice(GenericPluginTy &Plugin,
8080
int32_t DeviceId) {
81-
if (Plugin.is_device_initialized(DeviceId))
81+
if (Plugin.is_device_initialized(DeviceId)) {
82+
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
83+
(*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin,
84+
DeviceId)]]
85+
->setHasPendingImages(true);
8286
return true;
87+
}
8388

8489
// Initialize the device information for the RTL we are about to use.
8590
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
@@ -286,13 +291,194 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
286291
DP("Done unregistering library!\n");
287292
}
288293

294+
/// Map global data and execute pending ctors
295+
static int loadImagesOntoDevice(DeviceTy &Device) {
296+
/*
297+
* Map global data
298+
*/
299+
int32_t DeviceId = Device.DeviceID;
300+
int Rc = OFFLOAD_SUCCESS;
301+
{
302+
std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
303+
for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
304+
TranslationTable *TransTable =
305+
&PM->HostEntriesBeginToTransTable[HostEntriesBegin];
306+
DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
307+
TransTable->HostTable.EntriesEnd);
308+
if (TransTable->HostTable.EntriesBegin ==
309+
TransTable->HostTable.EntriesEnd) {
310+
// No host entry so no need to proceed
311+
continue;
312+
}
313+
314+
if (TransTable->TargetsTable[DeviceId] != 0) {
315+
// Library entries have already been processed
316+
continue;
317+
}
318+
319+
// 1) get image.
320+
assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
321+
"Not expecting a device ID outside the table's bounds!");
322+
__tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
323+
if (!Img) {
324+
REPORT("No image loaded for device id %d.\n", DeviceId);
325+
Rc = OFFLOAD_FAIL;
326+
break;
327+
}
328+
329+
// 2) Load the image onto the given device.
330+
auto BinaryOrErr = Device.loadBinary(Img);
331+
if (llvm::Error Err = BinaryOrErr.takeError()) {
332+
REPORT("Failed to load image %s\n",
333+
llvm::toString(std::move(Err)).c_str());
334+
Rc = OFFLOAD_FAIL;
335+
break;
336+
}
337+
338+
// 3) Create the translation table.
339+
llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
340+
TransTable->TargetsEntries[DeviceId];
341+
for (__tgt_offload_entry &Entry :
342+
llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
343+
__tgt_device_binary &Binary = *BinaryOrErr;
344+
345+
__tgt_offload_entry DeviceEntry = Entry;
346+
if (Entry.size) {
347+
if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
348+
&DeviceEntry.addr) != OFFLOAD_SUCCESS)
349+
REPORT("Failed to load symbol %s\n", Entry.name);
350+
351+
// If unified memory is active, the corresponding global is a device
352+
// reference to the host global. We need to initialize the pointer on
353+
// the device to point to the memory on the host.
354+
if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
355+
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
356+
if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
357+
Entry.size) != OFFLOAD_SUCCESS)
358+
REPORT("Failed to write symbol for USM %s\n", Entry.name);
359+
}
360+
} else if (Entry.addr) {
361+
if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
362+
OFFLOAD_SUCCESS)
363+
REPORT("Failed to load kernel %s\n", Entry.name);
364+
}
365+
DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
366+
DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
367+
DPxPTR(DeviceEntry.addr));
368+
369+
DeviceEntries.emplace_back(DeviceEntry);
370+
}
371+
372+
// Set the storage for the table and get a pointer to it.
373+
__tgt_target_table DeviceTable{&DeviceEntries[0],
374+
&DeviceEntries[0] + DeviceEntries.size()};
375+
TransTable->DeviceTables[DeviceId] = DeviceTable;
376+
__tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
377+
&TransTable->DeviceTables[DeviceId];
378+
379+
// 4) Verify whether the two table sizes match.
380+
size_t Hsize =
381+
TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
382+
size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
383+
384+
// Invalid image for these host entries!
385+
if (Hsize != Tsize) {
386+
REPORT(
387+
"Host and Target tables mismatch for device id %d [%zx != %zx].\n",
388+
DeviceId, Hsize, Tsize);
389+
TransTable->TargetsImages[DeviceId] = 0;
390+
TransTable->TargetsTable[DeviceId] = 0;
391+
Rc = OFFLOAD_FAIL;
392+
break;
393+
}
394+
395+
MappingInfoTy::HDTTMapAccessorTy HDTTMap =
396+
Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
397+
398+
__tgt_target_table *HostTable = &TransTable->HostTable;
399+
for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
400+
*CurrHostEntry = HostTable->EntriesBegin,
401+
*EntryDeviceEnd = TargetTable->EntriesEnd;
402+
CurrDeviceEntry != EntryDeviceEnd;
403+
CurrDeviceEntry++, CurrHostEntry++) {
404+
if (CurrDeviceEntry->size == 0)
405+
continue;
406+
407+
assert(CurrDeviceEntry->size == CurrHostEntry->size &&
408+
"data size mismatch");
409+
410+
// Fortran may use multiple weak declarations for the same symbol,
411+
// therefore we must allow for multiple weak symbols to be loaded from
412+
// the fat binary. Treat these mappings as any other "regular"
413+
// mapping. Add entry to map.
414+
if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
415+
CurrHostEntry->size))
416+
continue;
417+
418+
void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
419+
420+
// For indirect mapping, follow the indirection and map the actual
421+
// target.
422+
if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
423+
AsyncInfoTy AsyncInfo(Device);
424+
void *DevPtr;
425+
Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
426+
AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
427+
if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
428+
return OFFLOAD_FAIL;
429+
CurrDeviceEntryAddr = DevPtr;
430+
}
431+
432+
DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
433+
", name \"%s\"\n",
434+
DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
435+
CurrDeviceEntry->size, CurrDeviceEntry->name);
436+
HDTTMap->emplace(new HostDataToTargetTy(
437+
(uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
438+
(uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
439+
(uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
440+
(uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
441+
(uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
442+
false /*UseHoldRefCount*/, CurrHostEntry->name,
443+
true /*IsRefCountINF*/));
444+
445+
// Notify about the new mapping.
446+
if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size))
447+
return OFFLOAD_FAIL;
448+
}
449+
}
450+
Device.setHasPendingImages(false);
451+
}
452+
453+
if (Rc != OFFLOAD_SUCCESS)
454+
return Rc;
455+
456+
static Int32Envar DumpOffloadEntries =
457+
Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1);
458+
if (DumpOffloadEntries.get() == DeviceId)
459+
Device.dumpOffloadEntries();
460+
461+
return OFFLOAD_SUCCESS;
462+
}
463+
289464
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
290-
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
291-
if (DeviceNo >= ExclusiveDevicesAccessor->size())
292-
return createStringError(
293-
inconvertibleErrorCode(),
294-
"Device number '%i' out of range, only %i devices available", DeviceNo,
295-
ExclusiveDevicesAccessor->size());
465+
DeviceTy *DevicePtr;
466+
{
467+
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
468+
if (DeviceNo >= ExclusiveDevicesAccessor->size())
469+
return createStringError(
470+
inconvertibleErrorCode(),
471+
"Device number '%i' out of range, only %i devices available",
472+
DeviceNo, ExclusiveDevicesAccessor->size());
473+
474+
DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo];
475+
}
296476

297-
return *(*ExclusiveDevicesAccessor)[DeviceNo];
477+
// Check whether global data has been mapped for this device
478+
if (DevicePtr->hasPendingImages())
479+
if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS)
480+
return createStringError(inconvertibleErrorCode(),
481+
"Failed to load images on device '%i'",
482+
DeviceNo);
483+
return *DevicePtr;
298484
}

offload/src/interface.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "OpenMP/OMPT/Interface.h"
15+
#include "OffloadPolicy.h"
1516
#include "OpenMP/OMPT/Callback.h"
17+
#include "OpenMP/omp.h"
1618
#include "PluginManager.h"
19+
#include "omptarget.h"
1720
#include "private.h"
1821

1922
#include "Shared/EnvironmentVar.h"
@@ -32,6 +35,43 @@
3235
using namespace llvm::omp::target::ompt;
3336
#endif
3437

38+
// If offload is enabled, ensure that device DeviceID has been initialized.
39+
//
40+
// The returned bool indicates whether the offload is to the host device.
41+
// There are three possible results:
42+
// - Return false if the target device is ready for offload
43+
// - Return true without reporting a runtime error if offload is
44+
// disabled, perhaps because the initial device was specified.
45+
// - Report a runtime error and return true.
46+
//
47+
// If DeviceID == OFFLOAD_DEVICE_DEFAULT, set DeviceID to the default device.
48+
// This step might be skipped if offload is disabled.
49+
bool checkDevice(int64_t &DeviceID, ident_t *Loc) {
50+
if (OffloadPolicy::get(*PM).Kind == OffloadPolicy::DISABLED) {
51+
DP("Offload is disabled\n");
52+
return true;
53+
}
54+
55+
if (DeviceID == OFFLOAD_DEVICE_DEFAULT) {
56+
DeviceID = omp_get_default_device();
57+
DP("Use default device id %" PRId64 "\n", DeviceID);
58+
}
59+
60+
// Proposed behavior for OpenMP 5.2 in OpenMP spec github issue 2669.
61+
if (omp_get_num_devices() == 0) {
62+
DP("omp_get_num_devices() == 0 but offload is manadatory\n");
63+
handleTargetOutcome(false, Loc);
64+
return true;
65+
}
66+
67+
if (DeviceID == omp_get_initial_device()) {
68+
DP("Device is host (%" PRId64 "), returning as if offload is disabled\n",
69+
DeviceID);
70+
return true;
71+
}
72+
return false;
73+
}
74+
3575
////////////////////////////////////////////////////////////////////////////////
3676
/// adds requires flags
3777
EXTERN void __tgt_register_requires(int64_t Flags) {
@@ -85,7 +125,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
85125
DP("Entering data %s region for device %" PRId64 " with %d mappings\n",
86126
RegionName, DeviceId, ArgNum);
87127

88-
if (checkDeviceAndCtors(DeviceId, Loc)) {
128+
if (checkDevice(DeviceId, Loc)) {
89129
DP("Not offloading to device %" PRId64 "\n", DeviceId);
90130
return;
91131
}
@@ -266,7 +306,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
266306
"\n",
267307
DeviceId, DPxPTR(HostPtr));
268308

269-
if (checkDeviceAndCtors(DeviceId, Loc)) {
309+
if (checkDevice(DeviceId, Loc)) {
270310
DP("Not offloading to device %" PRId64 "\n", DeviceId);
271311
return OMP_TGT_FAIL;
272312
}
@@ -404,7 +444,7 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId,
404444
uint64_t LoopTripCount) {
405445
assert(PM && "Runtime not initialized");
406446
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));
407-
if (checkDeviceAndCtors(DeviceId, Loc)) {
447+
if (checkDevice(DeviceId, Loc)) {
408448
DP("Not offloading to device %" PRId64 "\n", DeviceId);
409449
return OMP_TGT_FAIL;
410450
}

0 commit comments

Comments
 (0)