Skip to content

Commit 3c73acc

Browse files
author
Hugh Delaney
committed
Multi dev ctx for hip adapter
1 parent 3605f74 commit 3c73acc

File tree

17 files changed

+980
-594
lines changed

17 files changed

+980
-594
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/context.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,13 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
3838
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
3939
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
4040
const ur_context_properties_t *, ur_context_handle_t *phContext) {
41-
std::ignore = DeviceCount;
42-
assert(DeviceCount == 1);
4341
ur_result_t RetErr = UR_RESULT_SUCCESS;
4442

4543
std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
4644
try {
4745
// Create a scoped context.
4846
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
49-
new ur_context_handle_t_{*phDevices});
47+
new ur_context_handle_t_{phDevices, DeviceCount});
5048

5149
static std::once_flag InitFlag;
5250
std::call_once(
@@ -78,7 +76,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
7876
case UR_CONTEXT_INFO_NUM_DEVICES:
7977
return ReturnValue(1);
8078
case UR_CONTEXT_INFO_DEVICES:
81-
return ReturnValue(hContext->getDevice());
79+
return ReturnValue(hContext->getDevices());
8280
case UR_CONTEXT_INFO_REFERENCE_COUNT:
8381
return ReturnValue(hContext->getReferenceCount());
8482
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES:
@@ -121,8 +119,10 @@ urContextRetain(ur_context_handle_t hContext) {
121119

122120
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
123121
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
122+
// FIXME this only returns the native context of the first device in the
123+
// SYCL context. This entry point should be deprecated.
124124
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
125-
hContext->getDevice()->getNativeContext());
125+
hContext->getDevices()[0]->getNativeContext());
126126
return UR_RESULT_SUCCESS;
127127
}
128128

sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,24 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
3131
/// with a given device and control access to said device from the user side.
3232
/// UR API context are objects that are passed to functions, and not bound
3333
/// to threads.
34-
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
35-
/// holds the HIP context data. The RAII object \ref ScopedDevice implements
36-
/// the active context behavior.
3734
///
38-
/// <b> Primary vs UserDefined context </b>
35+
/// Since the ur_context_handle_t can contain multiple devices, and a `hipCtx_t`
36+
/// refers to only a single device, the `hipCtx_t` is more tightly coupled to a
37+
/// ur_device_handle_t than a ur_context_handle_t. In order to remove some
38+
/// ambiguities about the different semantics of `ur_context_handle_t`s and
39+
/// native `hipCtx_t`, we access the native `hipCtx_t` solely through the
40+
/// ur_device_handle_t class, by using the RAII object \ref ScopedDevice, which
41+
/// sets the active device (by setting the active native `hipCtx_t`).
3942
///
40-
/// HIP has two different types of context, the Primary context,
41-
/// which is usable by all threads on a given process for a given device, and
42-
/// the aforementioned custom contexts.
43-
/// The HIP documentation, and performance analysis, suggest using the Primary
44-
/// context whenever possible. The Primary context is also used by the HIP
45-
/// Runtime API. For UR applications to interop with HIP Runtime API, they have
46-
/// to use the primary context - and make that active in the thread. The
47-
/// `ur_context_handle_t_` object can be constructed with a `kind` parameter
48-
/// that allows to construct a Primary or `UserDefined` context, so that
49-
/// the UR object interface is always the same.
43+
/// <b> Primary vs User-defined `hipCtx_t` </b>
44+
///
45+
/// HIP has two different types of `hipCtx_t`, the Primary context, which is
46+
/// usable by all threads on a given process for a given device, and the
47+
/// aforementioned custom `hipCtx_t`s.
48+
/// The HIP documentation, confirmed with performance analysis, suggests using
49+
/// the Primary context whenever possible. The Primary context is also used by
50+
/// the HIP Runtime API. For UR applications to interop with HIP Runtime API,
51+
/// they have to use the primary context - and make that active in the thread.
5052
///
5153
/// <b> Destructor callback </b>
5254
///
@@ -56,6 +58,15 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
5658
/// See proposal for details.
5759
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
5860
///
61+
/// <b> Memory Management for Devices in a Context </b>
62+
///
63+
/// A ur_buffer_ is associated with a ur_context_handle_t_, which may refer to
64+
/// multiple devices. Therefore the ur_buffer_ must handle a native allocation
65+
/// for each device in the context. UR is responsible for automatically
66+
/// handling event dependencies for kernels writing to or reading from the
67+
/// same ur_buffer_ and migrating memory between native allocations for
68+
/// devices in the same ur_context_handle_t_ if necessary.
69+
///
5970
struct ur_context_handle_t_ {
6071

6172
struct deleter_data {
@@ -67,15 +78,23 @@ struct ur_context_handle_t_ {
6778

6879
using native_type = hipCtx_t;
6980

70-
ur_device_handle_t DeviceId;
81+
std::vector<ur_device_handle_t> Devices;
82+
uint32_t NumDevices;
83+
7184
std::atomic_uint32_t RefCount;
7285

73-
ur_context_handle_t_(ur_device_handle_t DevId)
74-
: DeviceId{DevId}, RefCount{1} {
75-
urDeviceRetain(DeviceId);
86+
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
87+
: Devices{Devs, Devs + NumDevices}, NumDevices{NumDevices}, RefCount{1} {
88+
for (auto &Dev : Devices) {
89+
urDeviceRetain(Dev);
90+
}
7691
};
7792

78-
~ur_context_handle_t_() { urDeviceRelease(DeviceId); }
93+
~ur_context_handle_t_() {
94+
for (auto &Dev : Devices) {
95+
urDeviceRelease(Dev);
96+
}
97+
}
7998

8099
void invokeExtendedDeleters() {
81100
std::lock_guard<std::mutex> Guard(Mutex);
@@ -90,7 +109,9 @@ struct ur_context_handle_t_ {
90109
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
91110
}
92111

93-
ur_device_handle_t getDevice() const noexcept { return DeviceId; }
112+
std::vector<ur_device_handle_t> getDevices() const noexcept {
113+
return Devices;
114+
}
94115

95116
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
96117

sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ struct ur_device_handle_t_ {
2323
std::atomic_uint32_t RefCount;
2424
ur_platform_handle_t Platform;
2525
hipCtx_t HIPContext;
26+
size_t DeviceIndex; // The index of the device in the UR context
2627

2728
public:
2829
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
29-
ur_platform_handle_t Platform)
30+
ur_platform_handle_t Platform, size_t DeviceIndex)
3031
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
31-
HIPContext(Context) {}
32+
HIPContext(Context), DeviceIndex(DeviceIndex) {}
3233

3334
~ur_device_handle_t_() {
3435
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
@@ -41,17 +42,25 @@ struct ur_device_handle_t_ {
4142
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
4243

4344
hipCtx_t getNativeContext() { return HIPContext; };
45+
46+
// Returns the index of the device in question relative to the other devices
47+
// in the platform
48+
size_t getIndex() { return DeviceIndex; }
4449
};
4550

4651
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
4752

4853
namespace {
49-
/// RAII type to guarantee recovering original HIP context
50-
/// Scoped context is used across all UR HIP plugin implementation
51-
/// to activate the UR Context on the current thread, matching the
54+
/// RAII type to guarantee recovering original HIP device. In UR the
55+
/// `ScopedDevice` sets the active device by using the native underlying
56+
/// `hipCtx_t`. Since a UR context can contain multiple devices, whereas a
57+
/// `hipCtx_t` refers to a single device, it is semantically clearer to access
58+
/// the `hipCtx_t` through the UR device rather than the UR context.
59+
/// ScopedDevice is used throughout the UR HIP plugin implementation
60+
/// to activate the UR Device on the current thread, matching the
5261
/// HIP driver semantics where the context used for the HIP Driver
5362
/// API is the one active on the thread.
54-
/// The implementation tries to avoid replacing the hipCtx_t if it cans
63+
/// The implementation tries to avoid replacing the hipCtx_t if it can
5564
class ScopedDevice {
5665
hipCtx_t Original;
5766
bool NeedToRecover;

0 commit comments

Comments
 (0)