Skip to content

[SYCL][CUDA] Cuda adapter multi device context #10737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,14 @@ UR_APIEXPORT ur_result_t UR_APICALL
urContextCreate(uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *pProperties,
ur_context_handle_t *phContext) {
std::ignore = DeviceCount;
std::ignore = pProperties;

assert(DeviceCount == 1);
ur_result_t RetErr = UR_RESULT_SUCCESS;

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});
new ur_context_handle_t_{phDevices, DeviceCount});
*phContext = ContextPtr.release();
} catch (ur_result_t Err) {
RetErr = Err;
Expand All @@ -72,7 +70,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(1);
case UR_CONTEXT_INFO_DEVICES:
return ReturnValue(hContext->getDevice());
return ReturnValue(hContext->getDevices());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES: {
Expand All @@ -83,21 +81,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
return ReturnValue(Capabilities);
}
case UR_CONTEXT_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES: {
int Major = 0;
detail::ur::assertion(
cuDeviceGetAttribute(&Major,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
hContext->getDevice()->get()) == CUDA_SUCCESS);
// Return the lowest compute capability of all devices in context
int MinimumMajorComputeCapability = 0, Tmp = 0;
for (auto i = 0u; i < hContext->NumDevices; ++i) {
detail::ur::assertion(
cuDeviceGetAttribute(
&Tmp, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
hContext->getDevices()[i]->get()) == CUDA_SUCCESS);
MinimumMajorComputeCapability =
i == 0 ? Tmp : std::min(MinimumMajorComputeCapability, Tmp);
}
uint32_t Capabilities =
(Major >= 7) ? UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_DEVICE |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SYSTEM
: UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_DEVICE;
(MinimumMajorComputeCapability >= 7)
? UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_DEVICE |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SYSTEM
: UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_GROUP |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_DEVICE;
return ReturnValue(Capabilities);
}
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
Expand Down Expand Up @@ -134,9 +138,15 @@ urContextRetain(ur_context_handle_t hContext) {
return UR_RESULT_SUCCESS;
}

// FIXME this only returns the native context of the first device in the SYCL
// context
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
UR_ASSERT(hContext, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
UR_ASSERT(phNativeContext, UR_RESULT_ERROR_INVALID_NULL_POINTER);

*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevices()[0]->getNativeContext());
return UR_RESULT_SUCCESS;
}

Expand Down
100 changes: 46 additions & 54 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,24 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
/// with a given device and control access to said device from the user side.
/// UR API context are objects that are passed to functions, and not bound
/// to threads.
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
/// holds the CUDA context data. The RAII object \ref ScopedContext implements
/// the active context behavior.
///
/// <b> Primary vs User-defined context </b>
/// Since the ur_context_handle_t can contain multiple devices, and a CUcontext
/// refers to only a single device, the CUcontext is more tightly coupled to a
/// ur_device_handle_t than a ur_context_handle_t. In order to remove some
/// ambiguities about the different semantics of ur_context_handle_t s and
/// native CUcontext, we access the native CUcontext solely through the
/// ur_device_handle_t class, by using the RAII object \ref ScopedDevice, which
/// sets the active device (by setting the active native CUcontext).
///
/// CUDA has two different types of context, the Primary context,
/// which is usable by all threads on a given process for a given device, and
/// the aforementioned custom contexts.
/// The CUDA documentation, confirmed with performance analysis, suggest using
/// the Primary context whenever possible.
/// The Primary context is also used by the CUDA Runtime API.
/// For UR applications to interop with CUDA Runtime API, they have to use
/// the primary context - and make that active in the thread.
/// The `ur_context_handle_t_` object can be constructed with a `kind` parameter
/// that allows to construct a Primary or `user-defined` context, so that
/// the UR object interface is always the same.
/// <b> Primary vs User-defined CUcontext </b>
///
/// CUDA has two different types of CUcontext, the Primary context, which is
/// usable by all threads on a given process for a given device, and the
/// aforementioned custom CUcontexts. The CUDA documentation, confirmed with
/// performance analysis, suggest using the Primary context whenever possible.
/// The Primary context is also used by the CUDA Runtime API. For UR
/// applications to interop with CUDA Runtime API, they have to use the primary
/// context - and make that active in the thread.
///
/// <b> Destructor callback </b>
///
Expand All @@ -61,6 +62,20 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
/// <b> Memory Management for Devices in a Context <\b>
///
/// A ur_buffer_ is associated with a ur_context_handle_t_, which may refer to
/// multiple devices. Therefore the ur_buffer_ must handle a native allocation
/// for each device in the context. UR is responsible for automatically
/// handling event dependencies for kernels writing to or reading from the
/// same ur_buffer_ and migrating memory between native allocations for
/// devices in the same ur_context_handle_t_ if necessary.
///
/// TODO: This management of memory for devices in the same
/// ur_context_handle_t_ is currently only valid for buffers and not for
/// images.
///
///
struct ur_context_handle_t_ {

struct deleter_data {
Expand All @@ -72,16 +87,23 @@ struct ur_context_handle_t_ {

using native_type = CUcontext;

native_type CUContext;
ur_device_handle_t DeviceID;
std::vector<ur_device_handle_t> Devices;
uint32_t NumDevices{};

std::atomic_uint32_t RefCount;

ur_context_handle_t_(ur_device_handle_t_ *DevID)
: CUContext{DevID->getContext()}, DeviceID{DevID}, RefCount{1} {
urDeviceRetain(DeviceID);
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, NumDevices{NumDevices}, RefCount{1} {
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
}
};

~ur_context_handle_t_() { urDeviceRelease(DeviceID); }
~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}

void invokeExtendedDeleters() {
std::lock_guard<std::mutex> Guard(Mutex);
Expand All @@ -96,9 +118,9 @@ struct ur_context_handle_t_ {
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
}

ur_device_handle_t getDevice() const noexcept { return DeviceID; }

native_type get() const noexcept { return CUContext; }
std::vector<ur_device_handle_t> getDevices() const noexcept {
return Devices;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

Expand All @@ -117,33 +139,3 @@ struct ur_context_handle_t_ {
std::vector<deleter_data> ExtendedDeleters;
std::set<ur_usm_pool_handle_t> PoolHandles;
};

namespace {
class ScopedContext {
public:
ScopedContext(ur_context_handle_t Context) {
if (!Context) {
throw UR_RESULT_ERROR_INVALID_CONTEXT;
}

setContext(Context->get());
}

ScopedContext(CUcontext NativeContext) { setContext(NativeContext); }

~ScopedContext() {}

private:
void setContext(CUcontext Desired) {
CUcontext Original = nullptr;

UR_CHECK_ERROR(cuCtxGetCurrent(&Original));

// Make sure the desired context is active on the current thread, setting
// it if necessary
if (Original != Desired) {
UR_CHECK_ERROR(cuCtxSetCurrent(Desired));
}
}
};
} // namespace
4 changes: 2 additions & 2 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,

static constexpr uint32_t MaxWorkItemDimensions = 3u;

ScopedContext Active(hDevice->getContext());
ScopedDevice Active(hDevice);

switch ((uint32_t)propName) {
case UR_DEVICE_INFO_TYPE: {
Expand Down Expand Up @@ -1226,7 +1226,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
uint64_t *pDeviceTimestamp,
uint64_t *pHostTimestamp) {
CUevent Event;
ScopedContext Active(hDevice->getContext());
ScopedDevice Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(cuEventCreate(&Event, CU_EVENT_DEFAULT));
Expand Down
51 changes: 48 additions & 3 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ struct ur_device_handle_t_ {
bool MaxLocalMemSizeChosen{false};

public:
size_t DeviceIndex;

ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
ur_platform_handle_t platform)
ur_platform_handle_t platform, size_t deviceIndex)
: CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), RefCount{1},
Platform(platform) {
Platform(platform), DeviceIndex{deviceIndex} {

UR_CHECK_ERROR(cuDeviceGetAttribute(
&MaxBlockDimY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, cuDevice));
Expand Down Expand Up @@ -74,14 +76,18 @@ struct ur_device_handle_t_ {

native_type get() const noexcept { return CuDevice; };

CUcontext getContext() const noexcept { return CuContext; };
CUcontext getNativeContext() const noexcept { return CuContext; };

uint32_t getReferenceCount() const noexcept { return RefCount; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

uint64_t getElapsedTime(CUevent) const;

// Returns the index of the device in question relative to the other devices
// in the platform
size_t getIndex() const { return DeviceIndex; }

void saveMaxWorkItemSizes(size_t Size,
size_t *SaveMaxWorkItemSizes) noexcept {
memcpy(MaxWorkItemSizes, SaveMaxWorkItemSizes, Size);
Expand Down Expand Up @@ -112,3 +118,42 @@ struct ur_device_handle_t_ {
};

int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);

// This will be important for changing from device to device
namespace {
class ScopedDevice {
CUcontext Original;
bool NeedToRecover = false;

public:
ScopedDevice(CUcontext NativeContext) { setContext(NativeContext); }
ScopedDevice(ur_device_handle_t Device) {
if (!Device) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}
setContext(Device->getNativeContext());
}

~ScopedDevice() {
if (NeedToRecover) {
UR_CHECK_ERROR(cuCtxSetCurrent(Original));
}
}

private:
void setContext(CUcontext Desired) {

UR_CHECK_ERROR(cuCtxGetCurrent(&Original));

if (Original != nullptr) {
NeedToRecover = true;
}

// Make sure the desired context is active on the current thread, setting
// it if necessary
if (Original != Desired) {
UR_CHECK_ERROR(cuCtxSetCurrent(Desired));
}
}
};
} // namespace
Loading