Skip to content

[Offload] Provide a kernel library useable by the offload runtime #104168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions clang/lib/Driver/ToolChains/CommonArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1202,8 +1202,11 @@ bool tools::addOpenMPRuntime(const Compilation &C, ArgStringList &CmdArgs,
options::OPT_fno_openmp, false)) {
// We need libomptarget (liboffload) if it's the chosen offloading runtime.
if (Args.hasFlag(options::OPT_foffload_via_llvm,
options::OPT_fno_offload_via_llvm, false))
options::OPT_fno_offload_via_llvm, false)) {
CmdArgs.push_back("-lomptarget");
if (!Args.hasArg(options::OPT_nogpulib))
CmdArgs.append({"-lomptarget.devicertl", "-loffload.kernels"});
}
return false;
}

Expand Down Expand Up @@ -1240,7 +1243,7 @@ bool tools::addOpenMPRuntime(const Compilation &C, ArgStringList &CmdArgs,
CmdArgs.push_back("-lomptarget");

if (IsOffloadingHost && !Args.hasArg(options::OPT_nogpulib))
CmdArgs.push_back("-lomptarget.devicertl");
CmdArgs.append({"-lomptarget.devicertl", "-loffload.kernels"});

addArchSpecificRPath(TC, Args, CmdArgs);

Expand Down
1 change: 1 addition & 0 deletions offload/DeviceRTL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ elseif(LIBOMPTARGET_DEVICE_ARCHITECTURES STREQUAL "auto" OR
"${LIBOMPTARGET_NVPTX_DETECTED_ARCH_LIST};${LIBOMPTARGET_AMDGPU_DETECTED_ARCH_LIST}")
endif()
list(REMOVE_DUPLICATES LIBOMPTARGET_DEVICE_ARCHITECTURES)
set(LIBOMPTARGET_DEVICE_ARCHITECTURES ${LIBOMPTARGET_DEVICE_ARCHITECTURES} PARENT_SCOPE)

set(include_files
${include_directory}/Allocator.h
Expand Down
3 changes: 3 additions & 0 deletions offload/include/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ struct DeviceTy {
/// Calls the corresponding print device info function in the plugin.
bool printDeviceInfo();

/// Return the handle to the kernel with name \p Name in \p HandlePtr.
int32_t getKernelHandle(llvm::StringRef Name, void **HandlePtr);

/// Event related interfaces.
/// {
/// Create an event.
Expand Down
5 changes: 5 additions & 0 deletions offload/include/omptarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,11 @@ void __tgt_target_data_update_nowait_mapper(
int __tgt_target_kernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
int32_t ThreadLimit, void *HostPtr, KernelArgsTy *Args);

/// Launch the kernel \p KernelName with a CUDA style launch and the given grid
/// sizes and arguments (\p KernelArgs).
int __tgt_launch_by_name(ident_t *Loc, int64_t DeviceId, const char *KernelName,
KernelArgsTy *KernelArgs);

// Non-blocking synchronization for target nowait regions. This function
// acquires the asynchronous context from task data of the current task being
// executed and tries to query for the completion of its operations. If the
Expand Down
56 changes: 9 additions & 47 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2016,20 +2016,13 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}

virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
DeviceImageTy &Image) override {
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
if (Handler.isSymbolInImage(*this, Image, "amdgcn.device.fini"))
Image.setPendingGlobalDtors();

return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/true);
virtual Expected<StringRef>
getGlobalConstructorName(DeviceImageTy &Image) override {
return "amdgcn.device.init";
}

virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
DeviceImageTy &Image) override {
if (Image.hasPendingGlobalDtors())
return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
return Plugin::success();
virtual Expected<StringRef>
getGlobalDestructorName(DeviceImageTy &Image) override {
return "amdgcn.device.fini";
}

uint64_t getStreamBusyWaitMicroseconds() const { return OMPX_StreamBusyWait; }
Expand Down Expand Up @@ -2107,13 +2100,14 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
uint64_t getClockFrequency() const override { return ClockFrequency; }

/// Allocate and construct an AMDGPU kernel.
Expected<GenericKernelTy &> constructKernel(const char *Name) override {
Expected<GenericKernelTy &>
constructKernelImpl(llvm::StringRef Name) override {
// Allocate and construct the AMDGPU kernel.
AMDGPUKernelTy *AMDGPUKernel = Plugin.allocate<AMDGPUKernelTy>();
if (!AMDGPUKernel)
return Plugin::error("Failed to allocate memory for AMDGPU kernel");

new (AMDGPUKernel) AMDGPUKernelTy(Name);
new (AMDGPUKernel) AMDGPUKernelTy(Name.data());

return *AMDGPUKernel;
}
Expand Down Expand Up @@ -2791,38 +2785,6 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
using AMDGPUEventRef = AMDGPUResourceRef<AMDGPUEventTy>;
using AMDGPUEventManagerTy = GenericDeviceResourceManagerTy<AMDGPUEventRef>;

/// Common method to invoke a single threaded constructor or destructor
/// kernel by name.
///
/// \param Plugin Plugin whose global handler is used for the symbol lookup.
/// \param Image  Device image the kernel is initialized against.
/// \param IsCtor Selects "amdgcn.device.init" (true) or "amdgcn.device.fini"
///               (false) as the kernel to run.
/// \returns success if the constructor kernel is absent or the launch and
/// finalize succeed; otherwise the error from kernel init, launch, or
/// finalize.
Error callGlobalCtorDtorCommon(GenericPluginTy &Plugin, DeviceImageTy &Image,
bool IsCtor) {
const char *KernelName =
IsCtor ? "amdgcn.device.init" : "amdgcn.device.fini";
// Perform a quick check for the named kernel in the image. The kernel
// should be created by the 'amdgpu-lower-ctor-dtor' pass.
// Note that the presence check only guards the constructor path; for
// destructors no symbol check is performed here.
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
if (IsCtor && !Handler.isSymbolInImage(*this, Image, KernelName))
return Plugin::success();

// Construct a function-local AMDGPU kernel object for this one launch and
// initialize it against the image.
AMDGPUKernelTy AMDGPUKernel(KernelName);
if (auto Err = AMDGPUKernel.init(*this, Image))
return Err;

// No caller-provided async info is passed (nullptr).
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);

// Launch with one thread and one block, with value-initialized (empty)
// kernel arguments and launch parameters.
KernelArgsTy KernelArgs = {};
if (auto Err =
AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
/*NumBlocks=*/1ul, KernelArgs,
KernelLaunchParamsTy{}, AsyncInfoWrapper))
return Err;

// Finalize the wrapper and report whatever status it leaves in Err.
Error Err = Plugin::success();
AsyncInfoWrapper.finalize(Err);

return Err;
}

/// Detect if current architecture is an APU.
Error checkIfAPU() {
// TODO: replace with ROCr API once it becomes available.
Expand Down
33 changes: 21 additions & 12 deletions offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,18 +722,17 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error synchronize(__tgt_async_info *AsyncInfo);
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;

/// Invokes any global constructors on the device if present and is required
/// by the target.
virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
DeviceImageTy &Image) {
return Error::success();
/// Call the ctor/dtor of image \p Image, if available.
Error callGlobalCtorDtor(DeviceImageTy &Image, bool IsCtor);

/// Return the name of the global constructors on the device.
virtual Expected<StringRef> getGlobalConstructorName(DeviceImageTy &Image) {
return "";
}

/// Invokes any global destructors on the device if present and is required
/// by the target.
virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
DeviceImageTy &Image) {
return Error::success();
/// Return the name of the global destructors on the device.
virtual Expected<StringRef> getGlobalDestructorName(DeviceImageTy &Image) {
return "";
}

/// Query for the completion of the pending operations on the __tgt_async_info
Expand Down Expand Up @@ -928,8 +927,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
bool useAutoZeroCopy();
virtual bool useAutoZeroCopyImpl() { return false; }

/// Allocate and construct a kernel object.
virtual Expected<GenericKernelTy &> constructKernel(const char *Name) = 0;
/// Retrieve the kernel with name \p Name from image \p Image (or any image if
/// \p Image is null) and return it. If \p Optional is true, the function
/// returns success if there is no kernel with the given name.
Expected<GenericKernelTy *> getKernel(llvm::StringRef Name,
DeviceImageTy *Image = nullptr,
bool Optional = false);

/// Reference to the underlying plugin that created this device.
GenericPluginTy &Plugin;
Expand All @@ -947,6 +950,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
UInt32Envar("OFFLOAD_TRACK_NUM_KERNEL_LAUNCH_TRACES", 0);

private:
/// Allocate and construct a kernel object (users should use getKernel).
virtual Expected<GenericKernelTy &>
constructKernelImpl(llvm::StringRef Name) = 0;

/// Get and set the stack size and heap size for the device. If not used, the
/// plugin can implement the setters as no-op and setting the output
/// value to zero for the getters.
Expand Down Expand Up @@ -1046,6 +1053,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
private:
DeviceMemoryPoolTy DeviceMemoryPool = {nullptr, 0};
DeviceMemoryPoolTrackingTy DeviceMemoryPoolTracking = {0, 0, ~0U, 0};

DenseMap<StringRef, GenericKernelTy *> KernelMap;
};

/// Class implementing common functionalities of offload plugins. Each plugin
Expand Down
109 changes: 98 additions & 11 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include <cstdint>
#include <limits>
#include <string>

using namespace llvm;
using namespace omp;
Expand Down Expand Up @@ -809,7 +810,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {

Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
for (DeviceImageTy *Image : LoadedImages)
if (auto Err = callGlobalDestructors(Plugin, *Image))
if (auto Err = callGlobalCtorDtor(*Image, /*Ctor*/ false))
return Err;

if (OMPX_DebugKind.get() & uint32_t(DeviceDebugKind::AllocationTracker)) {
Expand Down Expand Up @@ -866,6 +867,37 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {

return deinitImpl();
}

/// Run the global constructor (\p IsCtor is true) or global destructor
/// (\p IsCtor is false) kernel of \p Image, if there is one.
///
/// The kernel name is obtained from getGlobalConstructorName() or
/// getGlobalDestructorName(). An empty name means the target has no such
/// kernel and nothing is launched. The kernel lookup is optional, so a name
/// that does not resolve to a kernel in \p Image is not an error either.
///
/// \returns success if there is nothing to run or the launch and finalize
/// succeed; otherwise the error from the name query, the kernel lookup, the
/// launch, or the finalize step.
Error GenericDeviceTy::callGlobalCtorDtor(DeviceImageTy &Image, bool IsCtor) {
  auto NameOrErr =
      IsCtor ? getGlobalConstructorName(Image) : getGlobalDestructorName(Image);
  if (auto Err = NameOrErr.takeError())
    return Err;
  // No error but no name, that means there is no ctor/dtor.
  if (NameOrErr->empty())
    return Plugin::success();

  auto KernelOrErr = getKernel(*NameOrErr, &Image, /*Optional=*/true);
  if (auto Err = KernelOrErr.takeError())
    return Err;

  if (GenericKernelTy *Kernel = *KernelOrErr) {
    // Value-initialize the launch arguments so that no field is left
    // indeterminate; only the first team/thread dimension is set explicitly.
    KernelArgsTy KernelArgs = {};
    KernelArgs.NumTeams[0] = KernelArgs.ThreadLimit[0] = 1;
    AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
    if (auto Err = Kernel->launch(*this, /*ArgPtrs=*/nullptr,
                                  /*ArgOffsets=*/nullptr, KernelArgs,
                                  AsyncInfoWrapper))
      return Err;

    Error Err = Plugin::success();
    AsyncInfoWrapper.finalize(Err);
    return Err;
  }

  // The kernel was optional and is not present in this image.
  return Plugin::success();
}

Expected<DeviceImageTy *>
GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
const __tgt_device_image *InputTgtImage) {
Expand Down Expand Up @@ -927,8 +959,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
#endif

// Call any global constructors present on the device.
if (auto Err = callGlobalConstructors(Plugin, *Image))
return std::move(Err);
if (auto Err = callGlobalCtorDtor(*Image, /*Ctor*/ true))
return Err;

// Return the pointer to the table of entries.
return Image;
Expand Down Expand Up @@ -1533,6 +1565,67 @@ Error GenericDeviceTy::printInfo() {
return Plugin::success();
}

/// Look up the kernel named \p Name, constructing and initializing it on
/// first use.
///
/// Kernels are cached in KernelMap by name; a non-null cached entry is
/// returned directly. Otherwise the symbol is searched for in \p ImagePtr if
/// it is non-null, or in each entry of LoadedImages until one yields an
/// initialized kernel.
///
/// \returns the kernel on success; nullptr if the symbol was not found in any
/// searched image and \p Optional is true; an error if the symbol was not
/// found and \p Optional is false, or if it was found but the kernel could not
/// be constructed or initialized.
Expected<GenericKernelTy *> GenericDeviceTy::getKernel(llvm::StringRef Name,
                                                       DeviceImageTy *ImagePtr,
                                                       bool Optional) {
  bool KernelFound = false;
  GenericKernelTy *&KernelPtr = KernelMap[Name];
  if (!KernelPtr) {
    GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();

    // Try to create the kernel from a single image. Returns nullptr if the
    // symbol is absent or if construction/initialization fails; KernelFound
    // records whether the symbol was seen so the two cases can be told apart.
    auto CheckImage = [&](DeviceImageTy &Image) -> GenericKernelTy * {
      if (!GHandler.isSymbolInImage(*this, Image, Name))
        return nullptr;
      KernelFound = true;

      auto KernelOrErr = constructKernelImpl(Name);
      if (Error Err = KernelOrErr.takeError()) {
        [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
        DP("Failed to construct kernel ('%s'): %s", Name.data(),
           ErrStr.c_str());
        return nullptr;
      }

      GenericKernelTy &Kernel = *KernelOrErr;
      if (auto Err = Kernel.init(*this, Image)) {
        [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
        DP("Failed to initialize kernel ('%s'): %s", Name.data(),
           ErrStr.c_str());
        return nullptr;
      }

      return &Kernel;
    };

    if (ImagePtr) {
      KernelPtr = CheckImage(*ImagePtr);
    } else {
      for (DeviceImageTy *Image : LoadedImages) {
        KernelPtr = CheckImage(*Image);
        if (KernelPtr)
          break;
      }
    }
  }

  // If we didn't find the kernel and it was optional, we do not emit an error.
  if (!KernelPtr && !KernelFound && Optional)
    return nullptr;
  // If we didn't find the kernel and it was not optional, we will emit an
  // error.
  if (!KernelPtr && !KernelFound) {
    // Build the optional suffix as a named std::string and pass its C string:
    // a std::string object must not be handed to a printf-style '%s'.
    std::string SearchNote;
    if (!ImagePtr)
      SearchNote =
          ", searched " + std::to_string(LoadedImages.size()) + " images";
    return Plugin::error("Kernel '%s' not found%s", Name.data(),
                         SearchNote.c_str());
  }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this misses a .data() or something.

  // If we found the kernel but couldn't initialize it, we will emit an error.
  if (!KernelPtr)
    return Plugin::error("Kernel '%s' failed to initialize", Name.data());
  // Found the kernel and initialized it.
  return KernelPtr;
}

Error GenericDeviceTy::createEvent(void **EventPtrStorage) {
return createEventImpl(EventPtrStorage);
}
Expand Down Expand Up @@ -2147,20 +2240,14 @@ int32_t GenericPluginTy::get_function(__tgt_device_binary Binary,

GenericDeviceTy &Device = Image.getDevice();

auto KernelOrErr = Device.constructKernel(Name);
auto KernelOrErr = Device.getKernel(Name, &Image);
if (Error Err = KernelOrErr.takeError()) {
REPORT("Failure to look up kernel: %s\n", toString(std::move(Err)).data());
return OFFLOAD_FAIL;
}

GenericKernelTy &Kernel = *KernelOrErr;
if (auto Err = Kernel.init(Device, Image)) {
REPORT("Failure to init kernel: %s\n", toString(std::move(Err)).data());
return OFFLOAD_FAIL;
}

// Note that this is not the kernel's device address.
*KernelPtr = &Kernel;
*KernelPtr = *KernelOrErr;
return OFFLOAD_SUCCESS;
}

Expand Down
Loading
Loading