Skip to content

[Offload] Ensure to load images when the device is used #103002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions offload/include/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ struct DeviceTy {
/// Ask the device whether the runtime should use auto zero-copy.
bool useAutoZeroCopy();

/// Check if there are pending images for this device.
bool hasPendingImages() const { return HasPendingImages; }

/// Indicate that there are pending images for this device or not.
void setHasPendingImages(bool V) { HasPendingImages = V; }

private:
/// Deinitialize the device (and plugin).
void deinit();
Expand All @@ -163,6 +169,9 @@ struct DeviceTy {

/// Handler to collect and organize host-2-device mapping information.
MappingInfoTy MappingInfo;

/// Flag to indicate pending images (true after construction).
bool HasPendingImages = true;
};

#endif
202 changes: 194 additions & 8 deletions offload/src/PluginManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {

bool PluginManager::initializeDevice(GenericPluginTy &Plugin,
int32_t DeviceId) {
if (Plugin.is_device_initialized(DeviceId))
if (Plugin.is_device_initialized(DeviceId)) {
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
(*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin,
DeviceId)]]
->setHasPendingImages(true);
return true;
}

// Initialize the device information for the RTL we are about to use.
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
Expand Down Expand Up @@ -286,13 +291,194 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
DP("Done unregistering library!\n");
}

/// Map global data and execute pending ctors
static int loadImagesOntoDevice(DeviceTy &Device) {
/*
* Map global data
*/
int32_t DeviceId = Device.DeviceID;
int Rc = OFFLOAD_SUCCESS;
{
std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
TranslationTable *TransTable =
&PM->HostEntriesBeginToTransTable[HostEntriesBegin];
DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
TransTable->HostTable.EntriesEnd);
if (TransTable->HostTable.EntriesBegin ==
TransTable->HostTable.EntriesEnd) {
// No host entry so no need to proceed
continue;
}

if (TransTable->TargetsTable[DeviceId] != 0) {
// Library entries have already been processed
continue;
}

// 1) get image.
assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
"Not expecting a device ID outside the table's bounds!");
__tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
if (!Img) {
REPORT("No image loaded for device id %d.\n", DeviceId);
Rc = OFFLOAD_FAIL;
break;
}

// 2) Load the image onto the given device.
auto BinaryOrErr = Device.loadBinary(Img);
if (llvm::Error Err = BinaryOrErr.takeError()) {
REPORT("Failed to load image %s\n",
llvm::toString(std::move(Err)).c_str());
Rc = OFFLOAD_FAIL;
break;
}

// 3) Create the translation table.
llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
TransTable->TargetsEntries[DeviceId];
for (__tgt_offload_entry &Entry :
llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
__tgt_device_binary &Binary = *BinaryOrErr;

__tgt_offload_entry DeviceEntry = Entry;
if (Entry.size) {
if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
&DeviceEntry.addr) != OFFLOAD_SUCCESS)
REPORT("Failed to load symbol %s\n", Entry.name);

// If unified memory is active, the corresponding global is a device
// reference to the host global. We need to initialize the pointer on
// the device to point to the memory on the host.
if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
Entry.size) != OFFLOAD_SUCCESS)
REPORT("Failed to write symbol for USM %s\n", Entry.name);
}
} else if (Entry.addr) {
if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
OFFLOAD_SUCCESS)
REPORT("Failed to load kernel %s\n", Entry.name);
}
DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
DPxPTR(DeviceEntry.addr));

DeviceEntries.emplace_back(DeviceEntry);
}

// Set the storage for the table and get a pointer to it.
__tgt_target_table DeviceTable{&DeviceEntries[0],
&DeviceEntries[0] + DeviceEntries.size()};
TransTable->DeviceTables[DeviceId] = DeviceTable;
__tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
&TransTable->DeviceTables[DeviceId];

// 4) Verify whether the two table sizes match.
size_t Hsize =
TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;

// Invalid image for these host entries!
if (Hsize != Tsize) {
REPORT(
"Host and Target tables mismatch for device id %d [%zx != %zx].\n",
DeviceId, Hsize, Tsize);
TransTable->TargetsImages[DeviceId] = 0;
TransTable->TargetsTable[DeviceId] = 0;
Rc = OFFLOAD_FAIL;
break;
}

MappingInfoTy::HDTTMapAccessorTy HDTTMap =
Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();

__tgt_target_table *HostTable = &TransTable->HostTable;
for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
*CurrHostEntry = HostTable->EntriesBegin,
*EntryDeviceEnd = TargetTable->EntriesEnd;
CurrDeviceEntry != EntryDeviceEnd;
CurrDeviceEntry++, CurrHostEntry++) {
if (CurrDeviceEntry->size == 0)
continue;

assert(CurrDeviceEntry->size == CurrHostEntry->size &&
"data size mismatch");

// Fortran may use multiple weak declarations for the same symbol,
// therefore we must allow for multiple weak symbols to be loaded from
// the fat binary. Treat these mappings as any other "regular"
// mapping. Add entry to map.
if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
CurrHostEntry->size))
continue;

void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;

// For indirect mapping, follow the indirection and map the actual
// target.
if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
AsyncInfoTy AsyncInfo(Device);
void *DevPtr;
Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
return OFFLOAD_FAIL;
CurrDeviceEntryAddr = DevPtr;
}

DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
", name \"%s\"\n",
DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
CurrDeviceEntry->size, CurrDeviceEntry->name);
HDTTMap->emplace(new HostDataToTargetTy(
(uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
(uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
(uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
(uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
(uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
false /*UseHoldRefCount*/, CurrHostEntry->name,
true /*IsRefCountINF*/));
Comment on lines +437 to +443
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have utils for void ptr arithmetic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just moved the code. Cleanup is separate.


// Notify about the new mapping.
if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size))
return OFFLOAD_FAIL;
}
}
Device.setHasPendingImages(false);
}

if (Rc != OFFLOAD_SUCCESS)
return Rc;

static Int32Envar DumpOffloadEntries =
Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1);
if (DumpOffloadEntries.get() == DeviceId)
Device.dumpOffloadEntries();

return OFFLOAD_SUCCESS;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the code above is just moved.

Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
if (DeviceNo >= ExclusiveDevicesAccessor->size())
return createStringError(
inconvertibleErrorCode(),
"Device number '%i' out of range, only %i devices available", DeviceNo,
ExclusiveDevicesAccessor->size());
DeviceTy *DevicePtr;
{
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
if (DeviceNo >= ExclusiveDevicesAccessor->size())
return createStringError(
inconvertibleErrorCode(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inconvertibleErrorCode(),

There's an overload that just takes the string which will add the error code for you. Maybe we should save that for a global cleanup?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Let's not conflate this too much. This part of the code is just moved, we should simplify the things afterwards, e.g., pointer arith and clearer APIs.

"Device number '%i' out of range, only %i devices available",
DeviceNo, ExclusiveDevicesAccessor->size());

DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo];
}

return *(*ExclusiveDevicesAccessor)[DeviceNo];
// Check whether global data has been mapped for this device
if (DevicePtr->hasPendingImages())
if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS)
return createStringError(inconvertibleErrorCode(),
"Failed to load images on device '%i'",
DeviceNo);
return *DevicePtr;
}
46 changes: 43 additions & 3 deletions offload/src/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
//===----------------------------------------------------------------------===//

#include "OpenMP/OMPT/Interface.h"
#include "OffloadPolicy.h"
#include "OpenMP/OMPT/Callback.h"
#include "OpenMP/omp.h"
#include "PluginManager.h"
#include "omptarget.h"
#include "private.h"

#include "Shared/EnvironmentVar.h"
Expand All @@ -32,6 +35,43 @@
using namespace llvm::omp::target::ompt;
#endif

// If offload is enabled, ensure that device DeviceID has been initialized.
//
// The return bool indicates if the offload is to the host device
// There are three possible results:
// - Return false if the taregt device is ready for offload
// - Return true without reporting a runtime error if offload is
// disabled, perhaps because the initial device was specified.
// - Report a runtime error and return true.
//
// If DeviceID == OFFLOAD_DEVICE_DEFAULT, set DeviceID to the default device.
// This step might be skipped if offload is disabled.
bool checkDevice(int64_t &DeviceID, ident_t *Loc) {
if (OffloadPolicy::get(*PM).Kind == OffloadPolicy::DISABLED) {
DP("Offload is disabled\n");
return true;
}

if (DeviceID == OFFLOAD_DEVICE_DEFAULT) {
DeviceID = omp_get_default_device();
DP("Use default device id %" PRId64 "\n", DeviceID);
}

// Proposed behavior for OpenMP 5.2 in OpenMP spec github issue 2669.
if (omp_get_num_devices() == 0) {
DP("omp_get_num_devices() == 0 but offload is manadatory\n");
handleTargetOutcome(false, Loc);
return true;
}

if (DeviceID == omp_get_initial_device()) {
DP("Device is host (%" PRId64 "), returning as if offload is disabled\n",
DeviceID);
return true;
}
return false;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is just moved and the last part was split off into PluginManager.cpp.

////////////////////////////////////////////////////////////////////////////////
/// adds requires flags
EXTERN void __tgt_register_requires(int64_t Flags) {
Expand Down Expand Up @@ -85,7 +125,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
DP("Entering data %s region for device %" PRId64 " with %d mappings\n",
RegionName, DeviceId, ArgNum);

if (checkDeviceAndCtors(DeviceId, Loc)) {
if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return;
}
Expand Down Expand Up @@ -266,7 +306,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
"\n",
DeviceId, DPxPTR(HostPtr));

if (checkDeviceAndCtors(DeviceId, Loc)) {
if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return OMP_TGT_FAIL;
}
Expand Down Expand Up @@ -404,7 +444,7 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId,
uint64_t LoopTripCount) {
assert(PM && "Runtime not initialized");
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));
if (checkDeviceAndCtors(DeviceId, Loc)) {
if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return OMP_TGT_FAIL;
}
Expand Down
Loading
Loading