Skip to content

Commit 0c6ce91

Browse files
[NFC][SYCL] Use context_impl & in sampler_impl ctor and near it (#19153)
`SetArgBasedOnType` argument is only used to pass to the `sampler_impl` ctor so update it. `getCGKernelInfo` is only called in a function that also calls `sampler_impl` ctor so updating its signuature allows to update that caller's local `ContextImpl` variable, so makes sense to do as part of this PR as well. Continuation of the refactoring in #18795 #18877 #18966 #18979 #18980 #18981 #19007 #19030 #19123 #19126
1 parent a28dca5 commit 0c6ce91

File tree

5 files changed

+25
-25
lines changed

5 files changed

+25
-25
lines changed

sycl/source/detail/sampler_impl.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,14 @@ sampler_impl::sampler_impl(coordinate_normalization_mode normalizationMode,
2424
verifyProps(MPropList);
2525
}
2626

27-
sampler_impl::sampler_impl(cl_sampler clSampler,
28-
const ContextImplPtr &syclContext) {
29-
const AdapterPtr &Adapter = syclContext->getAdapter();
27+
sampler_impl::sampler_impl(cl_sampler clSampler, context_impl &syclContext) {
28+
const AdapterPtr &Adapter = syclContext.getAdapter();
3029
ur_sampler_handle_t Sampler{};
3130
Adapter->call<UrApiKind::urSamplerCreateWithNativeHandle>(
3231
reinterpret_cast<ur_native_handle_t>(clSampler),
33-
syclContext->getHandleRef(), nullptr, &Sampler);
32+
syclContext.getHandleRef(), nullptr, &Sampler);
3433

35-
MContextToSampler[syclContext] = Sampler;
34+
MContextToSampler[syclContext.shared_from_this()] = Sampler;
3635
bool NormalizedCoords;
3736

3837
Adapter->call<UrApiKind::urSamplerGetInfo>(
@@ -95,10 +94,14 @@ sampler_impl::~sampler_impl() {
9594
}
9695

9796
ur_sampler_handle_t
98-
sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
97+
sampler_impl::getOrCreateSampler(context_impl &ContextImpl) {
98+
// Just for the `MContextToSampler` lookups. Could probably be changed once we
99+
// move to C++20 and would have heterogeneous lookup.
100+
std::shared_ptr<context_impl> ContextImplPtr = ContextImpl.shared_from_this();
101+
99102
{
100103
std::lock_guard<std::mutex> Lock(MMutex);
101-
auto It = MContextToSampler.find(ContextImpl);
104+
auto It = MContextToSampler.find(ContextImplPtr);
102105
if (It != MContextToSampler.end())
103106
return It->second;
104107
}
@@ -135,18 +138,18 @@ sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
135138

136139
ur_result_t errcode_ret = UR_RESULT_SUCCESS;
137140
ur_sampler_handle_t resultSampler = nullptr;
138-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
141+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
139142

140143
errcode_ret = Adapter->call_nocheck<UrApiKind::urSamplerCreate>(
141-
ContextImpl->getHandleRef(), &desc, &resultSampler);
144+
ContextImpl.getHandleRef(), &desc, &resultSampler);
142145

143146
if (errcode_ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE)
144147
throw sycl::exception(sycl::errc::feature_not_supported,
145148
"Images are not supported by this device.");
146149

147150
Adapter->checkUrResult(errcode_ret);
148151
std::lock_guard<std::mutex> Lock(MMutex);
149-
MContextToSampler[ContextImpl] = resultSampler;
152+
MContextToSampler[ContextImplPtr] = resultSampler;
150153

151154
return resultSampler;
152155
}

sycl/source/detail/sampler_impl.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,22 @@ enum class coordinate_normalization_mode : unsigned int;
3030
namespace detail {
3131

3232
class context_impl;
33-
using ContextImplPtr = std::shared_ptr<context_impl>;
3433

3534
class sampler_impl {
3635
public:
3736
sampler_impl(coordinate_normalization_mode normalizationMode,
3837
addressing_mode addressingMode, filtering_mode filteringMode,
3938
const property_list &propList);
4039

41-
sampler_impl(cl_sampler clSampler, const ContextImplPtr &syclContext);
40+
sampler_impl(cl_sampler clSampler, context_impl &syclContext);
4241

4342
addressing_mode get_addressing_mode() const;
4443

4544
filtering_mode get_filtering_mode() const;
4645

4746
coordinate_normalization_mode get_coordinate_normalization_mode() const;
4847

49-
ur_sampler_handle_t getOrCreateSampler(const ContextImplPtr &ContextImpl);
48+
ur_sampler_handle_t getOrCreateSampler(context_impl &ContextImpl);
5049

5150
~sampler_impl();
5251

@@ -56,7 +55,8 @@ class sampler_impl {
5655
/// Protects all the fields that can be changed by class' methods.
5756
std::mutex MMutex;
5857

59-
std::unordered_map<ContextImplPtr, ur_sampler_handle_t> MContextToSampler;
58+
std::unordered_map<std::shared_ptr<context_impl>, ur_sampler_handle_t>
59+
MContextToSampler;
6060

6161
coordinate_normalization_mode MCoordNormMode;
6262
addressing_mode MAddrMode;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,8 +2313,7 @@ void SetArgBasedOnType(
23132313
const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
23142314
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
23152315
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2316-
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2317-
size_t NextTrueIndex) {
2316+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
23182317
switch (Arg.MType) {
23192318
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23202319
break;
@@ -2442,7 +2441,7 @@ static ur_result_t SetKernelParamsAndLaunch(
24422441
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
24432442
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
24442443
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2445-
Queue.getContextImplPtr(), Arg, NextTrueIndex);
2444+
Queue.getContextImpl(), Arg, NextTrueIndex);
24462445
};
24472446
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24482447
}
@@ -2530,7 +2529,7 @@ static ur_result_t SetKernelParamsAndLaunch(
25302529

25312530
static std::tuple<ur_kernel_handle_t, std::shared_ptr<device_image_impl>,
25322531
const KernelArgMask *>
2533-
getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
2532+
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
25342533
device_impl &DeviceImpl,
25352534
std::vector<FastKernelCacheValPtr> &KernelCacheValsToRelease) {
25362535

@@ -2552,7 +2551,7 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
25522551
} else {
25532552
FastKernelCacheValPtr FastKernelCacheVal =
25542553
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
2555-
*ContextImpl, DeviceImpl, CommandGroup.MKernelName,
2554+
ContextImpl, DeviceImpl, CommandGroup.MKernelName,
25562555
CommandGroup.MKernelNameBasedCachePtr);
25572556
UrKernel = FastKernelCacheVal->MKernelHandle;
25582557
EliminatedArgMask = FastKernelCacheVal->MKernelArgMask;
@@ -2579,7 +2578,7 @@ ur_result_t enqueueImpCommandBufferKernel(
25792578
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;
25802579
const KernelArgMask *EliminatedArgMask = nullptr;
25812580

2582-
auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx);
2581+
context_impl &ContextImpl = *sycl::detail::getSyclObjImpl(Ctx);
25832582
std::tie(UrKernel, DeviceImageImpl, EliminatedArgMask) = getCGKernelInfo(
25842583
CommandGroup, ContextImpl, DeviceImpl, FastKernelCacheValsToRelease);
25852584

@@ -2599,7 +2598,7 @@ ur_result_t enqueueImpCommandBufferKernel(
25992598
AltUrKernels.push_back(AltUrKernel);
26002599
}
26012600

2602-
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
2601+
const sycl::detail::AdapterPtr &Adapter = ContextImpl.getAdapter();
26032602
auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &ContextImpl,
26042603
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
26052604
size_t NextTrueIndex) {

sycl/source/detail/scheduler/commands.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class context_impl;
4444
class DispatchHostTask;
4545

4646
using EventImplPtr = std::shared_ptr<detail::event_impl>;
47-
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
4847
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;
4948

5049
class Command;
@@ -749,8 +748,7 @@ void SetArgBasedOnType(
749748
const detail::AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
750749
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
751750
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
752-
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
753-
size_t NextTrueIndex);
751+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex);
754752

755753
template <typename FuncT>
756754
void applyFuncOnFilteredArgs(const KernelArgMask *EliminatedArgMask,

sycl/source/sampler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ sampler::sampler(coordinate_normalization_mode normalizationMode,
2222

2323
sampler::sampler(cl_sampler clSampler, const context &syclContext)
2424
: impl(std::make_shared<detail::sampler_impl>(
25-
clSampler, detail::getSyclObjImpl(syclContext))) {}
25+
clSampler, *detail::getSyclObjImpl(syclContext))) {}
2626

2727
addressing_mode sampler::get_addressing_mode() const {
2828
return impl->get_addressing_mode();

0 commit comments

Comments
 (0)