Skip to content

Commit c7549f9

Browse files
authored
[ESIMD] Fix slm_atomic_update() implementation for double type (#12337)
For the double type, the GenX intrinsic expects double vectors directly, without bit-casting them to integer types as is done for other types. This fix enables FMAX/FMIN/FCMPXCHG slm_atomic_update() for the double type. It requires a fairly recent GPU driver. Signed-off-by: Klochkov, Vyacheslav N <[email protected]>
1 parent d3a5f1d commit c7549f9

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4089,15 +4089,22 @@ slm_atomic_update_impl(simd<uint32_t, N> offsets, simd<T, N> src0,
40894089
constexpr lsc_data_size EDS = expand_data_size(finalize_data_size<T, DS>());
40904090
constexpr lsc_vector_size VS = to_lsc_vector_size<1>();
40914091
constexpr lsc_data_order Transposed = lsc_data_order::nontranspose;
4092-
using MsgT = typename lsc_expand_type<T>::type;
40934092
constexpr int IOp = lsc_to_internal_atomic_op<T, Op>();
4094-
simd<MsgT, N> Msg_data = lsc_format_input<MsgT>(src0);
4095-
simd<MsgT, N> Tmp =
4096-
__esimd_lsc_xatomic_slm_1<MsgT, IOp, cache_hint::none, cache_hint::none,
4097-
AddressScale, ImmOffset, EDS, VS, Transposed,
4098-
N>(pred.data(), offsets.data(),
4099-
Msg_data.data());
4100-
return lsc_format_ret<T>(Tmp);
4093+
if constexpr (std::is_same_v<T, double>) {
4094+
return __esimd_lsc_xatomic_slm_1<T, IOp, cache_hint::none, cache_hint::none,
4095+
AddressScale, ImmOffset, EDS, VS,
4096+
Transposed, N>(pred.data(), offsets.data(),
4097+
src0.data());
4098+
} else {
4099+
using MsgT = typename lsc_expand_type<T>::type;
4100+
simd<MsgT, N> Msg_data = lsc_format_input<MsgT>(src0);
4101+
simd<MsgT, N> Tmp =
4102+
__esimd_lsc_xatomic_slm_1<MsgT, IOp, cache_hint::none, cache_hint::none,
4103+
AddressScale, ImmOffset, EDS, VS, Transposed,
4104+
N>(pred.data(), offsets.data(),
4105+
Msg_data.data());
4106+
return lsc_format_ret<T>(Tmp);
4107+
}
41014108
}
41024109

41034110
/// SLM atomic.
@@ -4126,16 +4133,23 @@ __ESIMD_API simd<T, N> slm_atomic_update_impl(simd<uint32_t, N> offsets,
41264133
constexpr lsc_data_size EDS = expand_data_size(finalize_data_size<T, DS>());
41274134
constexpr lsc_vector_size VS = to_lsc_vector_size<1>();
41284135
constexpr lsc_data_order Transposed = lsc_data_order::nontranspose;
4129-
using MsgT = typename lsc_expand_type<T>::type;
41304136
constexpr int IOp = lsc_to_internal_atomic_op<T, Op>();
4131-
simd<MsgT, N> Msg_data0 = lsc_format_input<MsgT>(src0);
4132-
simd<MsgT, N> Msg_data1 = lsc_format_input<MsgT>(src1);
4133-
simd<MsgT, N> Tmp =
4134-
__esimd_lsc_xatomic_slm_2<MsgT, IOp, cache_hint::none, cache_hint::none,
4135-
AddressScale, ImmOffset, EDS, VS, Transposed,
4136-
N>(pred.data(), offsets.data(),
4137-
Msg_data0.data(), Msg_data1.data());
4138-
return lsc_format_ret<T>(Tmp);
4137+
if constexpr (std::is_same_v<T, double>) {
4138+
return __esimd_lsc_xatomic_slm_2<T, IOp, cache_hint::none, cache_hint::none,
4139+
AddressScale, ImmOffset, EDS, VS,
4140+
Transposed, N>(pred.data(), offsets.data(),
4141+
src0.data(), src1.data());
4142+
} else {
4143+
using MsgT = typename lsc_expand_type<T>::type;
4144+
simd<MsgT, N> Msg_data0 = lsc_format_input<MsgT>(src0);
4145+
simd<MsgT, N> Msg_data1 = lsc_format_input<MsgT>(src1);
4146+
simd<MsgT, N> Tmp =
4147+
__esimd_lsc_xatomic_slm_2<MsgT, IOp, cache_hint::none, cache_hint::none,
4148+
AddressScale, ImmOffset, EDS, VS, Transposed,
4149+
N>(pred.data(), offsets.data(),
4150+
Msg_data0.data(), Msg_data1.data());
4151+
return lsc_format_ret<T>(Tmp);
4152+
}
41394153
}
41404154

41414155
} // namespace detail

sycl/test-e2e/ESIMD/unified_memory_api/Inputs/atomic_update_slm.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,8 +612,7 @@ bool test_fp_types(queue q) {
612612

613613
if constexpr (Features == TestFeatures::DG2 ||
614614
Features == TestFeatures::PVC) {
615-
// TODO: fmin/max for double does not pass validation likely due to
616-
// a driver bug. fcmpwr is hanging.
615+
// TODO: fmin/fmax/fcmpxchg for double requires a newer GPU driver.
617616
if constexpr (!std::is_same_v<Op<double, N>, ImplLSCFmax<double, N>> &&
618617
!std::is_same_v<Op<double, N>, ImplLSCFmin<double, N>> &&
619618
!std::is_same_v<Op<double, N>, ImplLSCFcmpwr<double, N>>) {

0 commit comments

Comments
 (0)