Skip to content

Commit 456b9b4

Browse files
authored
[SYCL] Update sampler_impl to use context_impl instead of context (#17477)
Update `sampler_impl` to use `context_impl` instead of `context`. These changes reduce the number of `std::shared_ptr` copies, which improves performance.
1 parent 44d4633 commit 456b9b4

File tree

5 files changed

+32
-20
lines changed

5 files changed

+32
-20
lines changed

sycl/source/detail/sampler_impl.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ sampler_impl::sampler_impl(coordinate_normalization_mode normalizationMode,
2424
verifyProps(MPropList);
2525
}
2626

27-
sampler_impl::sampler_impl(cl_sampler clSampler, const context &syclContext) {
28-
const AdapterPtr &Adapter = getSyclObjImpl(syclContext)->getAdapter();
27+
sampler_impl::sampler_impl(cl_sampler clSampler,
28+
const ContextImplPtr &syclContext) {
29+
const AdapterPtr &Adapter = syclContext->getAdapter();
2930
ur_sampler_handle_t Sampler{};
3031
Adapter->call<UrApiKind::urSamplerCreateWithNativeHandle>(
3132
reinterpret_cast<ur_native_handle_t>(clSampler),
32-
getSyclObjImpl(syclContext)->getHandleRef(), nullptr, &Sampler);
33+
syclContext->getHandleRef(), nullptr, &Sampler);
3334

3435
MContextToSampler[syclContext] = Sampler;
3536
bool NormalizedCoords;
@@ -85,18 +86,19 @@ sampler_impl::~sampler_impl() {
8586
for (auto &Iter : MContextToSampler) {
8687
// TODO catch an exception and add it to the list of asynchronous
8788
// exceptions
88-
const AdapterPtr &Adapter = getSyclObjImpl(Iter.first)->getAdapter();
89+
const AdapterPtr &Adapter = Iter.first->getAdapter();
8990
Adapter->call<UrApiKind::urSamplerRelease>(Iter.second);
9091
}
9192
} catch (std::exception &e) {
9293
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~sample_impl", e);
9394
}
9495
}
9596

96-
ur_sampler_handle_t sampler_impl::getOrCreateSampler(const context &Context) {
97+
ur_sampler_handle_t
98+
sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
9799
{
98100
std::lock_guard<std::mutex> Lock(MMutex);
99-
auto It = MContextToSampler.find(Context);
101+
auto It = MContextToSampler.find(ContextImpl);
100102
if (It != MContextToSampler.end())
101103
return It->second;
102104
}
@@ -133,18 +135,18 @@ ur_sampler_handle_t sampler_impl::getOrCreateSampler(const context &Context) {
133135

134136
ur_result_t errcode_ret = UR_RESULT_SUCCESS;
135137
ur_sampler_handle_t resultSampler = nullptr;
136-
const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter();
138+
const AdapterPtr &Adapter = ContextImpl->getAdapter();
137139

138140
errcode_ret = Adapter->call_nocheck<UrApiKind::urSamplerCreate>(
139-
getSyclObjImpl(Context)->getHandleRef(), &desc, &resultSampler);
141+
ContextImpl->getHandleRef(), &desc, &resultSampler);
140142

141143
if (errcode_ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE)
142144
throw sycl::exception(sycl::errc::feature_not_supported,
143145
"Images are not supported by this device.");
144146

145147
Adapter->checkUrResult(errcode_ret);
146148
std::lock_guard<std::mutex> Lock(MMutex);
147-
MContextToSampler[Context] = resultSampler;
149+
MContextToSampler[ContextImpl] = resultSampler;
148150

149151
return resultSampler;
150152
}

sycl/source/detail/sampler_impl.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99
#pragma once
1010

1111
#include <sycl/__spirv/spirv_types.hpp>
12-
#include <sycl/context.hpp>
1312
#include <sycl/detail/export.hpp>
1413
#include <sycl/detail/ur.hpp>
1514
#include <sycl/property_list.hpp>
1615

1716
#include <mutex>
1817
#include <unordered_map>
1918

19+
#ifdef __SYCL_INTERNAL_API
20+
#include <sycl/detail/cl.h>
21+
#endif
22+
2023
namespace sycl {
2124
inline namespace _V1 {
2225

@@ -25,21 +28,25 @@ enum class filtering_mode : unsigned int;
2528
enum class coordinate_normalization_mode : unsigned int;
2629

2730
namespace detail {
31+
32+
class context_impl;
33+
using ContextImplPtr = std::shared_ptr<context_impl>;
34+
2835
class sampler_impl {
2936
public:
3037
sampler_impl(coordinate_normalization_mode normalizationMode,
3138
addressing_mode addressingMode, filtering_mode filteringMode,
3239
const property_list &propList);
3340

34-
sampler_impl(cl_sampler clSampler, const context &syclContext);
41+
sampler_impl(cl_sampler clSampler, const ContextImplPtr &syclContext);
3542

3643
addressing_mode get_addressing_mode() const;
3744

3845
filtering_mode get_filtering_mode() const;
3946

4047
coordinate_normalization_mode get_coordinate_normalization_mode() const;
4148

42-
ur_sampler_handle_t getOrCreateSampler(const context &Context);
49+
ur_sampler_handle_t getOrCreateSampler(const ContextImplPtr &ContextImpl);
4350

4451
~sampler_impl();
4552

@@ -49,7 +56,7 @@ class sampler_impl {
4956
/// Protects all the fields that can be changed by class' methods.
5057
std::mutex MMutex;
5158

52-
std::unordered_map<context, ur_sampler_handle_t> MContextToSampler;
59+
std::unordered_map<ContextImplPtr, ur_sampler_handle_t> MContextToSampler;
5360

5461
coordinate_normalization_mode MCoordNormMode;
5562
addressing_mode MAddrMode;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,7 +2316,8 @@ void SetArgBasedOnType(
23162316
const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
23172317
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
23182318
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2319-
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2319+
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2320+
size_t NextTrueIndex) {
23202321
switch (Arg.MType) {
23212322
case kernel_param_kind_t::kind_work_group_memory:
23222323
break;
@@ -2355,7 +2356,7 @@ void SetArgBasedOnType(
23552356
sampler *SamplerPtr = (sampler *)Arg.MPtr;
23562357
ur_sampler_handle_t Sampler =
23572358
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2358-
->getOrCreateSampler(Context);
2359+
->getOrCreateSampler(ContextImpl);
23592360
Adapter->call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
23602361
nullptr, Sampler);
23612362
break;
@@ -2414,7 +2415,7 @@ static ur_result_t SetKernelParamsAndLaunch(
24142415
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
24152416
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
24162417
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2417-
Queue->get_context(), Arg, NextTrueIndex);
2418+
Queue->getContextImplPtr(), Arg, NextTrueIndex);
24182419
};
24192420

24202421
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
@@ -2600,11 +2601,11 @@ ur_result_t enqueueImpCommandBufferKernel(
26002601
}
26012602

26022603
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
2603-
auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &Ctx,
2604+
auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &ContextImpl,
26042605
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
26052606
size_t NextTrueIndex) {
26062607
sycl::detail::SetArgBasedOnType(Adapter, UrKernel, DeviceImageImpl,
2607-
getMemAllocationFunc, Ctx, Arg,
2608+
getMemAllocationFunc, ContextImpl, Arg,
26082609
NextTrueIndex);
26092610
};
26102611
// Copy args for modification

sycl/source/detail/scheduler/commands.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,8 @@ void SetArgBasedOnType(
746746
const detail::AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
747747
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
748748
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
749-
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex);
749+
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
750+
size_t NextTrueIndex);
750751

751752
template <typename FuncT>
752753
void applyFuncOnFilteredArgs(const KernelArgMask *EliminatedArgMask,

sycl/source/sampler.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ sampler::sampler(coordinate_normalization_mode normalizationMode,
2121
normalizationMode, addressingMode, filteringMode, propList)) {}
2222

2323
sampler::sampler(cl_sampler clSampler, const context &syclContext)
24-
: impl(std::make_shared<detail::sampler_impl>(clSampler, syclContext)) {}
24+
: impl(std::make_shared<detail::sampler_impl>(
25+
clSampler, detail::getSyclObjImpl(syclContext))) {}
2526

2627
addressing_mode sampler::get_addressing_mode() const {
2728
return impl->get_addressing_mode();

0 commit comments

Comments
 (0)