Skip to content

Commit 34f0c10

Browse files
authored
[SYCL] Fix race that occurs when using sampler in parallel (#2267)
Signed-off-by: Alexander Flegontov <[email protected]>
1 parent 5b286f3 commit 34f0c10

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

sycl/source/detail/sampler_impl.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,50 @@ namespace detail {
1616
sampler_impl::sampler_impl(coordinate_normalization_mode normalizationMode,
1717
addressing_mode addressingMode,
1818
filtering_mode filteringMode)
19-
: m_CoordNormMode(normalizationMode), m_AddrMode(addressingMode),
20-
m_FiltMode(filteringMode) {}
19+
: MCoordNormMode(normalizationMode), MAddrMode(addressingMode),
20+
MFiltMode(filteringMode) {}
2121

2222
sampler_impl::sampler_impl(cl_sampler clSampler, const context &syclContext) {
2323

2424
RT::PiSampler Sampler = pi::cast<RT::PiSampler>(clSampler);
25-
m_contextToSampler[syclContext] = Sampler;
25+
MContextToSampler[syclContext] = Sampler;
2626
const detail::plugin &Plugin = getSyclObjImpl(syclContext)->getPlugin();
2727
Plugin.call<PiApiKind::piSamplerRetain>(Sampler);
2828
Plugin.call<PiApiKind::piSamplerGetInfo>(
2929
Sampler, PI_SAMPLER_INFO_NORMALIZED_COORDS, sizeof(pi_bool),
30-
&m_CoordNormMode, nullptr);
30+
&MCoordNormMode, nullptr);
3131
Plugin.call<PiApiKind::piSamplerGetInfo>(
3232
Sampler, PI_SAMPLER_INFO_ADDRESSING_MODE,
33-
sizeof(pi_sampler_addressing_mode), &m_AddrMode, nullptr);
33+
sizeof(pi_sampler_addressing_mode), &MAddrMode, nullptr);
3434
Plugin.call<PiApiKind::piSamplerGetInfo>(Sampler, PI_SAMPLER_INFO_FILTER_MODE,
3535
sizeof(pi_sampler_filter_mode),
36-
&m_FiltMode, nullptr);
36+
&MFiltMode, nullptr);
3737
}
3838

3939
sampler_impl::~sampler_impl() {
40-
for (auto &Iter : m_contextToSampler) {
40+
std::lock_guard<mutex_class> Lock(MMutex);
41+
for (auto &Iter : MContextToSampler) {
4142
// TODO catch an exception and add it to the list of asynchronous exceptions
4243
const detail::plugin &Plugin = getSyclObjImpl(Iter.first)->getPlugin();
4344
Plugin.call<PiApiKind::piSamplerRelease>(Iter.second);
4445
}
4546
}
4647

4748
RT::PiSampler sampler_impl::getOrCreateSampler(const context &Context) {
48-
if (m_contextToSampler[Context])
49-
return m_contextToSampler[Context];
49+
{
50+
std::lock_guard<mutex_class> Lock(MMutex);
51+
auto It = MContextToSampler.find(Context);
52+
if (It != MContextToSampler.end())
53+
return It->second;
54+
}
5055

5156
const pi_sampler_properties sprops[] = {
5257
PI_SAMPLER_INFO_NORMALIZED_COORDS,
53-
static_cast<pi_sampler_properties>(m_CoordNormMode),
58+
static_cast<pi_sampler_properties>(MCoordNormMode),
5459
PI_SAMPLER_INFO_ADDRESSING_MODE,
55-
static_cast<pi_sampler_properties>(m_AddrMode),
60+
static_cast<pi_sampler_properties>(MAddrMode),
5661
PI_SAMPLER_INFO_FILTER_MODE,
57-
static_cast<pi_sampler_properties>(m_FiltMode),
62+
static_cast<pi_sampler_properties>(MFiltMode),
5863
0};
5964

6065
RT::PiResult errcode_ret = PI_SUCCESS;
@@ -69,18 +74,19 @@ RT::PiSampler sampler_impl::getOrCreateSampler(const context &Context) {
6974
errcode_ret);
7075

7176
Plugin.checkPiResult(errcode_ret);
72-
m_contextToSampler[Context] = resultSampler;
77+
std::lock_guard<mutex_class> Lock(MMutex);
78+
MContextToSampler[Context] = resultSampler;
7379

74-
return m_contextToSampler[Context];
80+
return resultSampler;
7581
}
7682

77-
addressing_mode sampler_impl::get_addressing_mode() const { return m_AddrMode; }
83+
addressing_mode sampler_impl::get_addressing_mode() const { return MAddrMode; }
7884

79-
filtering_mode sampler_impl::get_filtering_mode() const { return m_FiltMode; }
85+
filtering_mode sampler_impl::get_filtering_mode() const { return MFiltMode; }
8086

8187
coordinate_normalization_mode
8288
sampler_impl::get_coordinate_normalization_mode() const {
83-
return m_CoordNormMode;
89+
return MCoordNormMode;
8490
}
8591

8692
} // namespace detail

sycl/source/detail/sampler_impl.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,6 @@ enum class coordinate_normalization_mode : unsigned int;
2323

2424
namespace detail {
2525
class __SYCL_EXPORT sampler_impl {
26-
public:
27-
std::unordered_map<context, RT::PiSampler> m_contextToSampler;
28-
29-
private:
30-
coordinate_normalization_mode m_CoordNormMode;
31-
addressing_mode m_AddrMode;
32-
filtering_mode m_FiltMode;
33-
3426
public:
3527
sampler_impl(coordinate_normalization_mode normalizationMode,
3628
addressing_mode addressingMode, filtering_mode filteringMode);
@@ -46,6 +38,16 @@ class __SYCL_EXPORT sampler_impl {
4638
RT::PiSampler getOrCreateSampler(const context &Context);
4739

4840
~sampler_impl();
41+
42+
private:
43+
/// Protects all the fields that can be changed by class' methods.
44+
mutex_class MMutex;
45+
46+
std::unordered_map<context, RT::PiSampler> MContextToSampler;
47+
48+
coordinate_normalization_mode MCoordNormMode;
49+
addressing_mode MAddrMode;
50+
filtering_mode MFiltMode;
4951
};
5052

5153
} // namespace detail

0 commit comments

Comments
 (0)