-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[Offload] Ensure to load images when the device is used #103002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -78,8 +78,13 @@ bool PluginManager::initializePlugin(GenericPluginTy &Plugin) { | |||
|
||||
bool PluginManager::initializeDevice(GenericPluginTy &Plugin, | ||||
int32_t DeviceId) { | ||||
if (Plugin.is_device_initialized(DeviceId)) | ||||
if (Plugin.is_device_initialized(DeviceId)) { | ||||
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor(); | ||||
(*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin, | ||||
DeviceId)]] | ||||
->setHasPendingImages(true); | ||||
return true; | ||||
} | ||||
|
||||
// Initialize the device information for the RTL we are about to use. | ||||
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor(); | ||||
|
@@ -286,13 +291,194 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) { | |||
DP("Done unregistering library!\n"); | ||||
} | ||||
|
||||
/// Map global data and execute pending ctors | ||||
static int loadImagesOntoDevice(DeviceTy &Device) { | ||||
/* | ||||
* Map global data | ||||
*/ | ||||
int32_t DeviceId = Device.DeviceID; | ||||
int Rc = OFFLOAD_SUCCESS; | ||||
{ | ||||
std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx); | ||||
for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) { | ||||
TranslationTable *TransTable = | ||||
&PM->HostEntriesBeginToTransTable[HostEntriesBegin]; | ||||
DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin, | ||||
TransTable->HostTable.EntriesEnd); | ||||
if (TransTable->HostTable.EntriesBegin == | ||||
TransTable->HostTable.EntriesEnd) { | ||||
// No host entry so no need to proceed | ||||
continue; | ||||
} | ||||
|
||||
if (TransTable->TargetsTable[DeviceId] != 0) { | ||||
// Library entries have already been processed | ||||
continue; | ||||
} | ||||
|
||||
// 1) get image. | ||||
assert(TransTable->TargetsImages.size() > (size_t)DeviceId && | ||||
"Not expecting a device ID outside the table's bounds!"); | ||||
__tgt_device_image *Img = TransTable->TargetsImages[DeviceId]; | ||||
if (!Img) { | ||||
REPORT("No image loaded for device id %d.\n", DeviceId); | ||||
Rc = OFFLOAD_FAIL; | ||||
break; | ||||
} | ||||
|
||||
// 2) Load the image onto the given device. | ||||
auto BinaryOrErr = Device.loadBinary(Img); | ||||
if (llvm::Error Err = BinaryOrErr.takeError()) { | ||||
REPORT("Failed to load image %s\n", | ||||
llvm::toString(std::move(Err)).c_str()); | ||||
Rc = OFFLOAD_FAIL; | ||||
break; | ||||
} | ||||
|
||||
// 3) Create the translation table. | ||||
llvm::SmallVector<__tgt_offload_entry> &DeviceEntries = | ||||
TransTable->TargetsEntries[DeviceId]; | ||||
for (__tgt_offload_entry &Entry : | ||||
llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) { | ||||
__tgt_device_binary &Binary = *BinaryOrErr; | ||||
|
||||
__tgt_offload_entry DeviceEntry = Entry; | ||||
if (Entry.size) { | ||||
if (Device.RTL->get_global(Binary, Entry.size, Entry.name, | ||||
&DeviceEntry.addr) != OFFLOAD_SUCCESS) | ||||
REPORT("Failed to load symbol %s\n", Entry.name); | ||||
|
||||
// If unified memory is active, the corresponding global is a device | ||||
// reference to the host global. We need to initialize the pointer on | ||||
// the device to point to the memory on the host. | ||||
if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) || | ||||
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) { | ||||
if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr, | ||||
Entry.size) != OFFLOAD_SUCCESS) | ||||
REPORT("Failed to write symbol for USM %s\n", Entry.name); | ||||
} | ||||
} else if (Entry.addr) { | ||||
if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) != | ||||
OFFLOAD_SUCCESS) | ||||
REPORT("Failed to load kernel %s\n", Entry.name); | ||||
} | ||||
DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n", | ||||
DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name, | ||||
DPxPTR(DeviceEntry.addr)); | ||||
|
||||
DeviceEntries.emplace_back(DeviceEntry); | ||||
} | ||||
|
||||
// Set the storage for the table and get a pointer to it. | ||||
__tgt_target_table DeviceTable{&DeviceEntries[0], | ||||
&DeviceEntries[0] + DeviceEntries.size()}; | ||||
TransTable->DeviceTables[DeviceId] = DeviceTable; | ||||
__tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] = | ||||
&TransTable->DeviceTables[DeviceId]; | ||||
|
||||
// 4) Verify whether the two table sizes match. | ||||
size_t Hsize = | ||||
TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin; | ||||
size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin; | ||||
|
||||
// Invalid image for these host entries! | ||||
if (Hsize != Tsize) { | ||||
REPORT( | ||||
"Host and Target tables mismatch for device id %d [%zx != %zx].\n", | ||||
DeviceId, Hsize, Tsize); | ||||
TransTable->TargetsImages[DeviceId] = 0; | ||||
TransTable->TargetsTable[DeviceId] = 0; | ||||
Rc = OFFLOAD_FAIL; | ||||
break; | ||||
} | ||||
|
||||
MappingInfoTy::HDTTMapAccessorTy HDTTMap = | ||||
Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor(); | ||||
|
||||
__tgt_target_table *HostTable = &TransTable->HostTable; | ||||
for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin, | ||||
*CurrHostEntry = HostTable->EntriesBegin, | ||||
*EntryDeviceEnd = TargetTable->EntriesEnd; | ||||
CurrDeviceEntry != EntryDeviceEnd; | ||||
CurrDeviceEntry++, CurrHostEntry++) { | ||||
if (CurrDeviceEntry->size == 0) | ||||
continue; | ||||
|
||||
assert(CurrDeviceEntry->size == CurrHostEntry->size && | ||||
"data size mismatch"); | ||||
|
||||
// Fortran may use multiple weak declarations for the same symbol, | ||||
// therefore we must allow for multiple weak symbols to be loaded from | ||||
// the fat binary. Treat these mappings as any other "regular" | ||||
// mapping. Add entry to map. | ||||
if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr, | ||||
CurrHostEntry->size)) | ||||
continue; | ||||
|
||||
void *CurrDeviceEntryAddr = CurrDeviceEntry->addr; | ||||
|
||||
// For indirect mapping, follow the indirection and map the actual | ||||
// target. | ||||
if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) { | ||||
AsyncInfoTy AsyncInfo(Device); | ||||
void *DevPtr; | ||||
Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *), | ||||
AsyncInfo, /*Entry=*/nullptr, &HDTTMap); | ||||
if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS) | ||||
return OFFLOAD_FAIL; | ||||
CurrDeviceEntryAddr = DevPtr; | ||||
} | ||||
|
||||
DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu" | ||||
", name \"%s\"\n", | ||||
DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr), | ||||
CurrDeviceEntry->size, CurrDeviceEntry->name); | ||||
HDTTMap->emplace(new HostDataToTargetTy( | ||||
(uintptr_t)CurrHostEntry->addr /*HstPtrBase*/, | ||||
(uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/, | ||||
(uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/, | ||||
(uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/, | ||||
(uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/, | ||||
false /*UseHoldRefCount*/, CurrHostEntry->name, | ||||
true /*IsRefCountINF*/)); | ||||
|
||||
// Notify about the new mapping. | ||||
if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size)) | ||||
return OFFLOAD_FAIL; | ||||
} | ||||
} | ||||
Device.setHasPendingImages(false); | ||||
} | ||||
|
||||
if (Rc != OFFLOAD_SUCCESS) | ||||
return Rc; | ||||
|
||||
static Int32Envar DumpOffloadEntries = | ||||
Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1); | ||||
if (DumpOffloadEntries.get() == DeviceId) | ||||
Device.dumpOffloadEntries(); | ||||
|
||||
return OFFLOAD_SUCCESS; | ||||
} | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All the code above is just moved. |
||||
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) { | ||||
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor(); | ||||
if (DeviceNo >= ExclusiveDevicesAccessor->size()) | ||||
return createStringError( | ||||
inconvertibleErrorCode(), | ||||
"Device number '%i' out of range, only %i devices available", DeviceNo, | ||||
ExclusiveDevicesAccessor->size()); | ||||
DeviceTy *DevicePtr; | ||||
{ | ||||
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor(); | ||||
if (DeviceNo >= ExclusiveDevicesAccessor->size()) | ||||
return createStringError( | ||||
inconvertibleErrorCode(), | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There's an overload that just takes the string which will add the error code for you. Maybe we should save that for a global cleanup? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Let's not conflate this too much. This part of the code is just moved, we should simplify the things afterwards, e.g., pointer arith and clearer APIs. |
||||
"Device number '%i' out of range, only %i devices available", | ||||
DeviceNo, ExclusiveDevicesAccessor->size()); | ||||
|
||||
DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo]; | ||||
} | ||||
|
||||
return *(*ExclusiveDevicesAccessor)[DeviceNo]; | ||||
// Check whether global data has been mapped for this device | ||||
if (DevicePtr->hasPendingImages()) | ||||
if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS) | ||||
return createStringError(inconvertibleErrorCode(), | ||||
"Failed to load images on device '%i'", | ||||
DeviceNo); | ||||
return *DevicePtr; | ||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,11 @@ | |
//===----------------------------------------------------------------------===// | ||
|
||
#include "OpenMP/OMPT/Interface.h" | ||
#include "OffloadPolicy.h" | ||
#include "OpenMP/OMPT/Callback.h" | ||
#include "OpenMP/omp.h" | ||
#include "PluginManager.h" | ||
#include "omptarget.h" | ||
#include "private.h" | ||
|
||
#include "Shared/EnvironmentVar.h" | ||
|
@@ -32,6 +35,43 @@ | |
using namespace llvm::omp::target::ompt; | ||
#endif | ||
|
||
// If offload is enabled, ensure that device DeviceID has been initialized. | ||
// | ||
// The return bool indicates if the offload is to the host device | ||
// There are three possible results: | ||
// - Return false if the taregt device is ready for offload | ||
// - Return true without reporting a runtime error if offload is | ||
// disabled, perhaps because the initial device was specified. | ||
// - Report a runtime error and return true. | ||
// | ||
// If DeviceID == OFFLOAD_DEVICE_DEFAULT, set DeviceID to the default device. | ||
// This step might be skipped if offload is disabled. | ||
bool checkDevice(int64_t &DeviceID, ident_t *Loc) { | ||
if (OffloadPolicy::get(*PM).Kind == OffloadPolicy::DISABLED) { | ||
DP("Offload is disabled\n"); | ||
return true; | ||
} | ||
|
||
if (DeviceID == OFFLOAD_DEVICE_DEFAULT) { | ||
DeviceID = omp_get_default_device(); | ||
DP("Use default device id %" PRId64 "\n", DeviceID); | ||
} | ||
|
||
// Proposed behavior for OpenMP 5.2 in OpenMP spec github issue 2669. | ||
if (omp_get_num_devices() == 0) { | ||
DP("omp_get_num_devices() == 0 but offload is manadatory\n"); | ||
handleTargetOutcome(false, Loc); | ||
return true; | ||
} | ||
|
||
if (DeviceID == omp_get_initial_device()) { | ||
DP("Device is host (%" PRId64 "), returning as if offload is disabled\n", | ||
DeviceID); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code is just moved and the last part was split off into PluginManager.cpp. |
||
//////////////////////////////////////////////////////////////////////////////// | ||
/// adds requires flags | ||
EXTERN void __tgt_register_requires(int64_t Flags) { | ||
|
@@ -85,7 +125,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase, | |
DP("Entering data %s region for device %" PRId64 " with %d mappings\n", | ||
RegionName, DeviceId, ArgNum); | ||
|
||
if (checkDeviceAndCtors(DeviceId, Loc)) { | ||
if (checkDevice(DeviceId, Loc)) { | ||
DP("Not offloading to device %" PRId64 "\n", DeviceId); | ||
return; | ||
} | ||
|
@@ -266,7 +306,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams, | |
"\n", | ||
DeviceId, DPxPTR(HostPtr)); | ||
|
||
if (checkDeviceAndCtors(DeviceId, Loc)) { | ||
if (checkDevice(DeviceId, Loc)) { | ||
DP("Not offloading to device %" PRId64 "\n", DeviceId); | ||
return OMP_TGT_FAIL; | ||
} | ||
|
@@ -404,7 +444,7 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId, | |
uint64_t LoopTripCount) { | ||
assert(PM && "Runtime not initialized"); | ||
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0))); | ||
if (checkDeviceAndCtors(DeviceId, Loc)) { | ||
if (checkDevice(DeviceId, Loc)) { | ||
DP("Not offloading to device %" PRId64 "\n", DeviceId); | ||
return OMP_TGT_FAIL; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we have utils for void ptr arithmetic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just moved the code. Cleanup is separate.