Skip to content

[SYCL] Attach auxiliary resources to memory objects #5588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/detail/buffer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class __SYCL_EXPORT buffer_impl final : public SYCLMemObjT {
~buffer_impl() {
try {
BaseT::updateHostMemory();
BaseT::detachResources();
} catch (...) {
}
destructorNotification(this);
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/detail/image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class __SYCL_EXPORT image_impl final : public SYCLMemObjT {
~image_impl() {
try {
BaseT::updateHostMemory();
BaseT::detachResources();
} catch (...) {
}
}
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/CL/sycl/detail/sycl_mem_obj_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI {
// members must be alive.
void updateHostMemory();

// Detach additional resources associated with the memory object.
void detachResources() const;

public:
__SYCL_DLL_LOCAL bool useHostPtr() {
return has_property<property::buffer::use_host_ptr>() ||
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ class __SYCL_EXPORT handler {
/// They are then forwarded to command group and destroyed only after
/// the command group finishes the work on device/host.
/// The 'MSharedPtrStorage' suits that need.
/// NOTE: This is no longer in use and should be removed with next ABI break.
///
/// @param ReduObj is a pointer to object that must be stored.
void addReduction(const std::shared_ptr<const void> &ReduObj) {
Expand Down
41 changes: 34 additions & 7 deletions sycl/include/sycl/ext/oneapi/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,16 @@ template <typename T> struct AreAllButLastReductions<T> {
static constexpr bool value = !std::is_base_of<reduction_impl_base, T>::value;
};

/// Helper for attaching a resource to the lifetime of the memory associated
/// with accessor.
__SYCL_EXPORT void attachLifetime(std::shared_ptr<const void> &Resource,
detail::AccessorBaseHost &AttachTo);

/// Helper for attaching a resource to the lifetime of USM memory.
__SYCL_EXPORT void attachLifetime(std::shared_ptr<queue_impl> &Queue,
std::shared_ptr<const void> &Resource,
void *AttachTo);

/// This class encapsulates the reduction variable/accessor,
/// the reduction operator and an optional operator identity.
template <typename T, class BinaryOperation, int Dims, bool IsUSM,
Expand Down Expand Up @@ -645,7 +655,7 @@ class reduction_impl : private reduction_impl_base {

accessor<T, buffer_dim, access::mode::read>
getReadAccToPreviousPartialReds(handler &CGH) const {
CGH.addReduction(MOutBufPtr);
attachResourceLifetimeToMem(CGH, MOutBufPtr);
return {*MOutBufPtr, CGH};
}

Expand Down Expand Up @@ -673,7 +683,7 @@ class reduction_impl : private reduction_impl_base {
std::enable_if_t<!IsOneWG, rw_accessor_type>
getWriteMemForPartialReds(size_t Size, handler &CGH) {
MOutBufPtr = std::make_shared<buffer<T, buffer_dim>>(range<1>(Size));
CGH.addReduction(MOutBufPtr);
attachResourceLifetimeToMem(CGH, MOutBufPtr);
return createHandlerWiredReadWriteAccessor(CGH, *MOutBufPtr);
}

Expand All @@ -691,7 +701,7 @@ class reduction_impl : private reduction_impl_base {

// Create a new output buffer and return an accessor to it.
MOutBufPtr = std::make_shared<buffer<T, buffer_dim>>(range<1>(Size));
CGH.addReduction(MOutBufPtr);
attachResourceLifetimeToMem(CGH, MOutBufPtr);
return createHandlerWiredReadWriteAccessor(CGH, *MOutBufPtr);
}

Expand All @@ -707,19 +717,19 @@ class reduction_impl : private reduction_impl_base {
return *MRWAcc;

auto RWReduVal = std::make_shared<T>(MIdentity);
CGH.addReduction(RWReduVal);
attachResourceLifetimeToMem(CGH, RWReduVal);
MOutBufPtr = std::make_shared<buffer<T, 1>>(RWReduVal.get(), range<1>(1));
CGH.addReduction(MOutBufPtr);
attachResourceLifetimeToMem(CGH, MOutBufPtr);
return createHandlerWiredReadWriteAccessor(CGH, *MOutBufPtr);
}

accessor<int, 1, access::mode::read_write, access::target::device,
access::placeholder::false_t>
getReadWriteAccessorToInitializedGroupsCounter(handler &CGH) {
auto CounterMem = std::make_shared<int>(0);
CGH.addReduction(CounterMem);
attachResourceLifetimeToMem(CGH, CounterMem);
auto CounterBuf = std::make_shared<buffer<int, 1>>(CounterMem.get(), 1);
CGH.addReduction(CounterBuf);
attachResourceLifetimeToMem(CGH, CounterBuf);
return {*CounterBuf, CGH};
}

Expand Down Expand Up @@ -767,6 +777,23 @@ class reduction_impl : private reduction_impl_base {
return Acc;
}

/// Attaches the resource to the lifetime of the associated memory of the
/// reduction.
void attachResourceLifetimeToMem(handler &CGH,
std::shared_ptr<const void> Resource) const {
#ifndef __SYCL_DEVICE_ONLY__
if (is_usm)
detail::attachLifetime(CGH.MQueue, Resource, MUSMPointer);
else if (MDWAcc != nullptr)
detail::attachLifetime(Resource, *MDWAcc);
else
detail::attachLifetime(Resource, *MRWAcc);
#else
(void)CGH;
(void)Resource;
#endif
}

/// Identity of the BinaryOperation.
/// The result of BinaryOperation(X, MIdentity) is equal to X for any X.
const T MIdentity;
Expand Down
36 changes: 36 additions & 0 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ cl_context context_impl::get() const {
bool context_impl::is_host() const { return MHostContext; }

context_impl::~context_impl() {
// In case a user is leaking a memory object we may still have attached
// resources. These resources will be released with the context, but we need
// to do it before we release the backend context to avoid confusing errors.
std::unordered_map<const void *, std::vector<std::shared_ptr<const void>>>
USMLTARes;
{
std::lock_guard<std::mutex> lock(MUSMLifetimeAttachedResourcesMutex);
std::swap(USMLTARes, MUSMLifetimeAttachedResources);
}
USMLTARes.clear();

for (auto LibProg : MCachedLibPrograms) {
assert(LibProg.second && "Null program must not be kept in the cache");
getPlugin().call<PiApiKind::piProgramRelease>(LibProg.second);
Expand Down Expand Up @@ -206,6 +217,31 @@ pi_native_handle context_impl::getNative() const {
return Handle;
}

void context_impl::attachLifetimeToUSM(std::shared_ptr<const void> &Resource,
const void *AttachTo) {
std::lock_guard<std::mutex> lock(MUSMLifetimeAttachedResourcesMutex);
auto AttachedResourcesIt = MUSMLifetimeAttachedResources.find(AttachTo);
if (AttachedResourcesIt != MUSMLifetimeAttachedResources.end())
AttachedResourcesIt->second.push_back(Resource);
else
MUSMLifetimeAttachedResources.insert({AttachTo, {Resource}});
}

void context_impl::detachUSMLifetimeResources(const void *AttachedTo) {
// Swap the attached resources and let them go out of scope without the lock.
// This is required as they could potentially have their own attached
// resources they need to detach.
std::vector<std::shared_ptr<const void>> AttachedResources;
{
std::lock_guard<std::mutex> lock(MUSMLifetimeAttachedResourcesMutex);
auto AttachedResourcesIt = MUSMLifetimeAttachedResources.find(AttachedTo);
if (AttachedResourcesIt == MUSMLifetimeAttachedResources.end())
return;
std::swap(AttachedResourcesIt->second, AttachedResources);
MUSMLifetimeAttachedResources.erase(AttachedResourcesIt);
}
}

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
19 changes: 19 additions & 0 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ class context_impl {
/// \return a native handle.
pi_native_handle getNative() const;

/// Attach a resource to a USM pointer.
///
/// \param Resource is the resource to attach to the USM pointer
/// \param AttachTo is the USM pointer to attach the resource to
void attachLifetimeToUSM(std::shared_ptr<const void> &Resource,
const void *AttachTo);

/// Detach all resources attached to a USM pointer.
///
/// \param AttachedTo is the USM pointer to detach resources from
void detachUSMLifetimeResources(const void *AttachedTo);

private:
async_handler MAsyncHandler;
std::vector<device> MDevices;
Expand All @@ -177,6 +189,13 @@ class context_impl {
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
MCachedLibPrograms;
mutable KernelProgramCache MKernelProgramCache;

/// Matches USM pointers to attached resources.
std::unordered_map<const void *, std::vector<std::shared_ptr<const void>>>
MUSMLifetimeAttachedResources;

/// Protects m_USMLifetimeAttachedResources.
std::mutex MUSMLifetimeAttachedResourcesMutex;
};

} // namespace detail
Expand Down
16 changes: 16 additions & 0 deletions sycl/source/detail/global_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ ThreadPool &GlobalHandler::getHostTaskThreadPool() {
return TP;
}

std::unordered_map<const SYCLMemObjI *,
std::vector<std::shared_ptr<const void>>> &
GlobalHandler::getMemObjLifetimeAttachedResources() {
return getOrCreate(MMemObjLifetimeAttachedResources);
}

std::mutex &GlobalHandler::getMemObjLifetimeAttachedResourcesMutex() {
return getOrCreate(MMemObjLifetimeAttachedResourcesMutex);
}

void releaseDefaultContexts() {
// Release shared-pointers to SYCL objects.
#ifndef _WIN32
Expand All @@ -121,6 +131,12 @@ void GlobalHandler::registerDefaultContextReleaseHandler() {
}

void shutdown() {
// In case a user is leaking a memory object we may still have attached
// resources. These resources will be released with the context, but we need
// to do it before we release the backend context to avoid confusing errors.
if (GlobalHandler::instance().MMemObjLifetimeAttachedResources.Inst)
GlobalHandler::instance().MMemObjLifetimeAttachedResources.Inst->clear();

// Ensure neither host task is working so that no default context is accessed
// upon its release
if (GlobalHandler::instance().MHostTaskThreadPool.Inst)
Expand Down
12 changes: 12 additions & 0 deletions sycl/source/detail/global_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class plugin;
class device_filter_list;
class XPTIRegistry;
class ThreadPool;
class SYCLMemObjI;

using PlatformImplPtr = std::shared_ptr<platform_impl>;
using ContextImplPtr = std::shared_ptr<context_impl>;
Expand Down Expand Up @@ -70,6 +71,11 @@ class GlobalHandler {
std::mutex &getHandlerExtendedMembersMutex();
ThreadPool &getHostTaskThreadPool();

std::unordered_map<const SYCLMemObjI *,
std::vector<std::shared_ptr<const void>>> &
getMemObjLifetimeAttachedResources();
std::mutex &getMemObjLifetimeAttachedResourcesMutex();

static void registerDefaultContextReleaseHandler();

private:
Expand Down Expand Up @@ -105,6 +111,12 @@ class GlobalHandler {
InstWithLock<std::mutex> MHandlerExtendedMembersMutex;
// Thread pool for host task and event callbacks execution
InstWithLock<ThreadPool> MHostTaskThreadPool;

/// TODO: On ABI break this should be made part of SYCLMemObjT.
InstWithLock<std::unordered_map<const SYCLMemObjI *,
std::vector<std::shared_ptr<const void>>>>
MMemObjLifetimeAttachedResources;
InstWithLock<std::mutex> MMemObjLifetimeAttachedResourcesMutex;
};
} // namespace detail
} // namespace sycl
Expand Down
21 changes: 21 additions & 0 deletions sycl/source/detail/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,27 @@ reduGetMaxWGSize(std::shared_ptr<sycl::detail::queue_impl> Queue,
return WGSize;
}

__SYCL_EXPORT void attachLifetime(std::shared_ptr<const void> &Resource,
detail::AccessorBaseHost &AttachTo) {
SYCLMemObjI *MemObj = getSyclObjImpl(AttachTo)->MSYCLMemObj;
// On ABI break this should attach directly to the memory object.
std::lock_guard<std::mutex> lock(
GlobalHandler::instance().getMemObjLifetimeAttachedResourcesMutex());
auto &AttachedResourcesMap =
GlobalHandler::instance().getMemObjLifetimeAttachedResources();
auto AttachedResourcesIt = AttachedResourcesMap.find(MemObj);
if (AttachedResourcesIt != AttachedResourcesMap.end())
AttachedResourcesIt->second.push_back(Resource);
else
AttachedResourcesMap.insert({MemObj, {Resource}});
}

__SYCL_EXPORT void attachLifetime(std::shared_ptr<queue_impl> &Queue,
std::shared_ptr<const void> &Resource,
void *AttachTo) {
Queue->getContextImplPtr()->attachLifetimeToUSM(Resource, AttachTo);
}

} // namespace detail
} // namespace oneapi
} // namespace ext
Expand Down
23 changes: 23 additions & 0 deletions sycl/source/detail/sycl_mem_obj_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <CL/sycl/detail/sycl_mem_obj_t.hpp>
#include <detail/context_impl.hpp>
#include <detail/event_impl.hpp>
#include <detail/global_handler.hpp>
#include <detail/plugin.hpp>
#include <detail/scheduler/scheduler.hpp>

Expand Down Expand Up @@ -90,6 +91,28 @@ void SYCLMemObjT::updateHostMemory() {
pi::cast<RT::PiMem>(MInteropMemObject));
}
}

// TODO: With ABI break the attached resources can be held by this type. When
// that happens this will be obsolete as the resources will automatically be
// destroyed with the object.
void SYCLMemObjT::detachResources() const {
// Swap the attached resources and let them go out of scope without the lock.
// This is required as they could potentially have their own attached
// resources they need to detach.
std::vector<std::shared_ptr<const void>> AttachedResources;
{
std::lock_guard<std::mutex> lock(
GlobalHandler::instance().getMemObjLifetimeAttachedResourcesMutex());
auto &AttachedResourcesMap =
GlobalHandler::instance().getMemObjLifetimeAttachedResources();
auto AttachedResourcesIt = AttachedResourcesMap.find(this);
if (AttachedResourcesIt == AttachedResourcesMap.end())
return;
std::swap(AttachedResourcesIt->second, AttachedResources);
AttachedResourcesMap.erase(AttachedResourcesIt);
}
}

const plugin &SYCLMemObjT::getPlugin() const {
assert((MInteropContext != nullptr) &&
"Trying to get Plugin from SYCLMemObjT with nullptr ContextImpl.");
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/usm/usm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ void free(void *Ptr, const context &Ctxt, const detail::code_location &CL) {
const detail::plugin &Plugin = CtxImpl->getPlugin();
Plugin.call<PiApiKind::piextUSMFree>(C, Ptr);
}
// Detach resources.
detail::getSyclObjImpl(Ctxt)->detachUSMLifetimeResources(Ptr);
}

// For ABI compatibility
Expand Down
3 changes: 3 additions & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3697,6 +3697,8 @@ _ZN2cl4sycl3ext6oneapi10level_zero12make_programERKNS0_7contextEm
_ZN2cl4sycl3ext6oneapi10level_zero13make_platformEm
_ZN2cl4sycl3ext6oneapi15filter_selectorC1ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN2cl4sycl3ext6oneapi15filter_selectorC2ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN2cl4sycl3ext6oneapi6detail14attachLifetimeERSt10shared_ptrIKvERNS0_6detail16AccessorBaseHostE
_ZN2cl4sycl3ext6oneapi6detail14attachLifetimeERSt10shared_ptrINS0_6detail10queue_implEERS4_IKvEPv
_ZN2cl4sycl3ext6oneapi6detail16reduGetMaxWGSizeESt10shared_ptrINS0_6detail10queue_implEEm
_ZN2cl4sycl3ext6oneapi6detail17reduComputeWGSizeEmmRm
_ZN2cl4sycl3ext6oneapi6detail33reduGetMaxNumConcurrentWorkGroupsESt10shared_ptrINS0_6detail10queue_implEE
Expand Down Expand Up @@ -4148,6 +4150,7 @@ _ZNK2cl4sycl6detail10image_implILi3EE4sizeEv
_ZNK2cl4sycl6detail10image_implILi3EE7getTypeEv
_ZNK2cl4sycl6detail10image_implILi3EE9get_countEv
_ZNK2cl4sycl6detail10image_implILi3EE9get_rangeEv
_ZNK2cl4sycl6detail11SYCLMemObjT15detachResourcesEv
_ZNK2cl4sycl6detail11SYCLMemObjT9getPluginEv
_ZNK2cl4sycl6detail11SYCLMemObjT9isInteropEv
_ZNK2cl4sycl6detail11stream_impl22get_max_statement_sizeEv
Expand Down
3 changes: 3 additions & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,8 @@
?atanpi@__host_std@cl@@YA?AVhalf@half_impl@detail@sycl@2@V34562@@Z
?atanpi@__host_std@cl@@YAMM@Z
?atanpi@__host_std@cl@@YANN@Z
?attachLifetime@detail@oneapi@ext@sycl@cl@@YAXAEAV?$shared_ptr@$$CBX@std@@AEAVAccessorBaseHost@145@@Z
?attachLifetime@detail@oneapi@ext@sycl@cl@@YAXAEAV?$shared_ptr@Vqueue_impl@detail@sycl@cl@@@std@@AEAV?$shared_ptr@$$CBX@7@PEAX@Z
?barrier@handler@sycl@cl@@QEAAXAEBV?$vector@Vevent@sycl@cl@@V?$allocator@Vevent@sycl@cl@@@std@@@std@@@Z
?barrier@handler@sycl@cl@@QEAAXXZ
?begin@exception_list@sycl@cl@@QEBA?AV?$_Vector_const_iterator@V?$_Vector_val@U?$_Simple_types@Vexception_ptr@std@@@std@@@std@@@std@@XZ
Expand Down Expand Up @@ -1689,6 +1691,7 @@
?depends_on@handler@sycl@cl@@QEAAXAEBV?$vector@Vevent@sycl@cl@@V?$allocator@Vevent@sycl@cl@@@std@@@std@@@Z
?depends_on@handler@sycl@cl@@QEAAXVevent@23@@Z
?destructorNotification@buffer_impl@detail@sycl@cl@@QEAAXPEAX@Z
?detachResources@SYCLMemObjT@detail@sycl@cl@@IEBAXXZ
?determineHostPtr@SYCLMemObjT@detail@sycl@cl@@IEAAXAEBV?$shared_ptr@Vcontext_impl@detail@sycl@cl@@@std@@_NAEAPEAXAEA_N@Z
?device_has@queue@sycl@cl@@QEBA_NW4aspect@23@@Z
?die@pi@detail@sycl@cl@@YAXPEBD@Z
Expand Down