@@ -2699,32 +2699,40 @@ ESIMD_INLINE ESIMD_NODEBUG std::enable_if_t<
2699
2699
scatter_impl (AccessorTy acc, simd<T, N> vals, simd<uint32_t , N> offsets,
2700
2700
uint32_t glob_offset, simd_mask<N> mask) {
2701
2701
2702
- static_assert (sizeof (T) <= 4 && detail::isPowerOf2 (N, 32 ),
2703
- " Unexpected type or vector length" );
2704
- constexpr int TypeSizeLog2 = detail::ElemsPerAddrEncoding<sizeof (T)>();
2705
- // TODO (performance) use hardware-supported scale once BE supports it
2706
- constexpr int16_t scale = 0 ;
2707
- const auto si = __ESIMD_NS::get_surface_index (acc);
2708
-
2709
- if constexpr (sizeof (T) < 4 ) {
2710
- using Tint = std::conditional_t <std::is_integral_v<T>, T,
2711
- detail::uint_type_t <sizeof (T)>>;
2712
- using Treal = __raw_t <T>;
2713
- simd<Tint, N> vals_int = bitcast<Tint, Treal, N>(std::move (vals).data ());
2714
- using PromoT = typename std::conditional_t <std::is_signed<Tint>::value,
2715
- int32_t , uint32_t >;
2716
- const simd<PromoT, N> promo_vals = convert<PromoT>(std::move (vals_int));
2717
- __esimd_scatter_scaled<PromoT, N, decltype (si), TypeSizeLog2, scale>(
2718
- mask.data (), si, glob_offset, offsets.data (), promo_vals.data ());
2702
+ static_assert (detail::isPowerOf2 (N, 32 ), " Unexpected vector length" );
2703
+ if constexpr (sizeof (T) == 8 ) {
2704
+ scatter_impl<uint32_t , N>(
2705
+ acc, vals.template bit_cast_view <uint32_t >().template select <N, 2 >(0 ),
2706
+ offsets, glob_offset, mask);
2707
+ scatter_impl<uint32_t , N>(
2708
+ acc, vals.template bit_cast_view <uint32_t >().template select <N, 2 >(1 ),
2709
+ offsets, glob_offset + sizeof (uint32_t ), mask);
2719
2710
} else {
2720
- using Treal = __raw_t <T>;
2721
- if constexpr (!std::is_same_v<Treal, T>) {
2722
- simd<Treal, N> Values = vals.template bit_cast_view <Treal>();
2723
- __esimd_scatter_scaled<Treal, N, decltype (si), TypeSizeLog2, scale>(
2724
- mask.data (), si, glob_offset, offsets.data (), Values.data ());
2711
+ constexpr int TypeSizeLog2 = detail::ElemsPerAddrEncoding<sizeof (T)>();
2712
+ // TODO (performance) use hardware-supported scale once BE supports it
2713
+ constexpr int16_t scale = 0 ;
2714
+ const auto si = __ESIMD_NS::get_surface_index (acc);
2715
+
2716
+ if constexpr (sizeof (T) < 4 ) {
2717
+ using Tint = std::conditional_t <std::is_integral_v<T>, T,
2718
+ detail::uint_type_t <sizeof (T)>>;
2719
+ using Treal = __raw_t <T>;
2720
+ simd<Tint, N> vals_int = bitcast<Tint, Treal, N>(std::move (vals).data ());
2721
+ using PromoT = typename std::conditional_t <std::is_signed<Tint>::value,
2722
+ int32_t , uint32_t >;
2723
+ const simd<PromoT, N> promo_vals = convert<PromoT>(std::move (vals_int));
2724
+ __esimd_scatter_scaled<PromoT, N, decltype (si), TypeSizeLog2, scale>(
2725
+ mask.data (), si, glob_offset, offsets.data (), promo_vals.data ());
2725
2726
} else {
2726
- __esimd_scatter_scaled<T, N, decltype (si), TypeSizeLog2, scale>(
2727
- mask.data (), si, glob_offset, offsets.data (), vals.data ());
2727
+ using Treal = __raw_t <T>;
2728
+ if constexpr (!std::is_same_v<Treal, T>) {
2729
+ simd<Treal, N> Values = vals.template bit_cast_view <Treal>();
2730
+ __esimd_scatter_scaled<Treal, N, decltype (si), TypeSizeLog2, scale>(
2731
+ mask.data (), si, glob_offset, offsets.data (), Values.data ());
2732
+ } else {
2733
+ __esimd_scatter_scaled<T, N, decltype (si), TypeSizeLog2, scale>(
2734
+ mask.data (), si, glob_offset, offsets.data (), vals.data ());
2735
+ }
2728
2736
}
2729
2737
}
2730
2738
}
@@ -2736,42 +2744,50 @@ __ESIMD_API std::enable_if_t<
2736
2744
simd<T, N>>
2737
2745
gather_impl (AccessorTy acc, simd<uint32_t , N> offsets, uint32_t glob_offset,
2738
2746
simd_mask<N> mask) {
2739
- static_assert (sizeof (T) <= 4 && detail::isPowerOf2 (N, 32 ),
2740
- " Unexpected type or vector length" );
2741
-
2742
- constexpr int TypeSizeLog2 = detail::ElemsPerAddrEncoding<sizeof (T)>();
2743
- // TODO (performance) use hardware-supported scale once BE supports it
2744
- constexpr uint32_t scale = 0 ;
2745
- const auto si = get_surface_index (acc);
2746
-
2747
- if constexpr (sizeof (T) < 4 ) {
2748
- using Tint = std::conditional_t <std::is_integral_v<T>, T,
2749
- detail::uint_type_t <sizeof (T)>>;
2750
- using Treal = __raw_t <T>;
2751
- static_assert (std::is_integral<Tint>::value,
2752
- " only integral 1- & 2-byte types are supported" );
2753
- using PromoT = typename std::conditional_t <std::is_signed<Tint>::value,
2754
- int32_t , uint32_t >;
2755
- simd<PromoT, N> promo_vals =
2756
- __esimd_gather_masked_scaled2<PromoT, N, decltype (si), TypeSizeLog2,
2757
- scale>(si, glob_offset, offsets.data (),
2758
- mask.data ());
2759
- auto Res = convert<Tint>(promo_vals);
2760
-
2761
- if constexpr (!std::is_same_v<Tint, T>) {
2762
- return detail::bitcast<Treal, Tint, N>(Res.data ());
2763
- } else {
2764
- return Res;
2765
- }
2747
+ static_assert (detail::isPowerOf2 (N, 32 ), " Unexpected vector length" );
2748
+
2749
+ if constexpr (sizeof (T) == 8 ) {
2750
+ simd<T, N> Res;
2751
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(0 ) =
2752
+ gather_impl<uint32_t , N>(acc, offsets, glob_offset, mask);
2753
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(1 ) =
2754
+ gather_impl<uint32_t , N>(acc, offsets, glob_offset + sizeof (uint32_t ),
2755
+ mask);
2756
+ return Res;
2766
2757
} else {
2767
2758
using Treal = __raw_t <T>;
2768
- simd<Treal, N> Res = __esimd_gather_masked_scaled2<Treal, N, decltype (si),
2769
- TypeSizeLog2, scale>(
2770
- si, glob_offset, offsets.data (), mask.data ());
2771
- if constexpr (!std::is_same_v<Treal, T>) {
2772
- return Res.template bit_cast_view <T>();
2759
+ constexpr int TypeSizeLog2 = detail::ElemsPerAddrEncoding<sizeof (T)>();
2760
+ // TODO (performance) use hardware-supported scale once BE supports it
2761
+ constexpr uint32_t scale = 0 ;
2762
+ const auto si = get_surface_index (acc);
2763
+ if constexpr (sizeof (T) < 4 ) {
2764
+ using Tint = std::conditional_t <std::is_integral_v<T>, T,
2765
+ detail::uint_type_t <sizeof (T)>>;
2766
+
2767
+ static_assert (std::is_integral<Tint>::value,
2768
+ " only integral 1- & 2-byte types are supported" );
2769
+ using PromoT = typename std::conditional_t <std::is_signed<Tint>::value,
2770
+ int32_t , uint32_t >;
2771
+ simd<PromoT, N> promo_vals =
2772
+ __esimd_gather_masked_scaled2<PromoT, N, decltype (si), TypeSizeLog2,
2773
+ scale>(si, glob_offset, offsets.data (),
2774
+ mask.data ());
2775
+ auto Res = convert<Tint>(promo_vals);
2776
+
2777
+ if constexpr (!std::is_same_v<Tint, T>) {
2778
+ return detail::bitcast<Treal, Tint, N>(Res.data ());
2779
+ } else {
2780
+ return Res;
2781
+ }
2773
2782
} else {
2774
- return Res;
2783
+ simd<Treal, N> Res = __esimd_gather_masked_scaled2<Treal, N, decltype (si),
2784
+ TypeSizeLog2, scale>(
2785
+ si, glob_offset, offsets.data (), mask.data ());
2786
+ if constexpr (!std::is_same_v<Treal, T>) {
2787
+ return Res.template bit_cast_view <T>();
2788
+ } else {
2789
+ return Res;
2790
+ }
2775
2791
}
2776
2792
}
2777
2793
}
@@ -2927,7 +2943,7 @@ __ESIMD_API
2927
2943
return gather<T, N>(__ESIMD_DNS::accessorToPointer<T>(acc, glob_offset),
2928
2944
byte_offsets, mask);
2929
2945
#else
2930
- if constexpr (sizeof (T) > 4 || !( detail::isPowerOf2 (N, 32 ) )) {
2946
+ if constexpr (! detail::isPowerOf2 (N, 32 )) {
2931
2947
// Requires DG2 or PVC.
2932
2948
simd<T, N> PassThru; // Intentionally undefined
2933
2949
byte_offsets += glob_offset;
@@ -3136,7 +3152,7 @@ gather(AccessorT acc, simd<OffsetT, N / VS> byte_offsets,
3136
3152
" hint is cache_level::L2 now." );
3137
3153
3138
3154
if constexpr (L1Hint != cache_hint::none || L2Hint != cache_hint::none ||
3139
- VS > 1 || sizeof (T) > 4 || !(detail::isPowerOf2 (N, 32 ))) {
3155
+ VS > 1 || !(detail::isPowerOf2 (N, 32 ))) {
3140
3156
simd<T, N> PassThru; // Intentionally undefined
3141
3157
return detail::gather_impl<T, N, VS, L1Hint, L2Hint,
3142
3158
detail::lsc_data_size::default_size>(
@@ -3344,13 +3360,13 @@ gather(AccessorT acc, OffsetSimdViewT byte_offsets, PropertyListT props = {}) {
3344
3360
// /
3345
3361
// /
3346
3362
template <typename T, int N, typename AccessorTy>
3347
- __ESIMD_API std:: enable_if_t <
3348
- ( sizeof (T) <= 4 ) && (N == 1 || N == 8 || N == 16 || N == 32 ) &&
3349
- detail::is_device_accessor_with_v<AccessorTy,
3350
- detail::accessor_mode_cap::can_write>>
3351
- scatter (AccessorTy acc, simd<detail::DeviceAccessorOffsetT, N> offsets,
3352
- simd<T, N> vals, detail::DeviceAccessorOffsetT glob_offset = 0 ,
3353
- simd_mask<N> mask = 1 ) {
3363
+ __ESIMD_API
3364
+ std:: enable_if_t <(detail::isPowerOf2(N, 32 ) ) &&
3365
+ detail::is_device_accessor_with_v<
3366
+ AccessorTy, detail::accessor_mode_cap::can_write>>
3367
+ scatter (AccessorTy acc, simd<detail::DeviceAccessorOffsetT, N> offsets,
3368
+ simd<T, N> vals, detail::DeviceAccessorOffsetT glob_offset = 0 ,
3369
+ simd_mask<N> mask = 1 ) {
3354
3370
#ifdef __ESIMD_FORCE_STATELESS_MEM
3355
3371
scatter<T, N>(__ESIMD_DNS::accessorToPointer<T>(acc, glob_offset), offsets,
3356
3372
vals, mask);
@@ -3362,7 +3378,7 @@ scatter(AccessorTy acc, simd<detail::DeviceAccessorOffsetT, N> offsets,
3362
3378
#ifdef __ESIMD_FORCE_STATELESS_MEM
3363
3379
template <typename T, int N, typename AccessorTy, typename Toffset>
3364
3380
__ESIMD_API std::enable_if_t <
3365
- (sizeof (T) <= 4 ) && (N == 1 || N == 8 || N == 16 || N == 32 ) &&
3381
+ (detail::isPowerOf2(N, 32 ) ) &&
3366
3382
detail::is_device_accessor_with_v<AccessorTy,
3367
3383
detail::accessor_mode_cap::can_write> &&
3368
3384
std::is_integral_v<Toffset> && !std::is_same_v<Toffset, uint64_t >>
@@ -3902,9 +3918,27 @@ slm_gather(simd<uint32_t, N / VS> byte_offsets, simd_mask<N / VS> mask,
3902
3918
detail::lsc_data_size::default_size>(
3903
3919
byte_offsets, mask, pass_thru);
3904
3920
} else {
3905
- using MsgT = detail::__raw_t <T>;
3906
- return __esimd_slm_gather_ld<MsgT, N, Alignment>(
3907
- byte_offsets.data (), mask.data (), pass_thru.data ());
3921
+ if constexpr (sizeof (T) == 8 ) {
3922
+ simd<T, N> Res;
3923
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(0 ) =
3924
+ __esimd_slm_gather_ld<uint32_t , N, Alignment>(
3925
+ byte_offsets.data (), mask.data (),
3926
+ (pass_thru.template bit_cast_view <uint32_t >()
3927
+ .template select <N, 2 >(0 ))
3928
+ .data ());
3929
+ simd<uint32_t , N / VS> Offset = byte_offsets + sizeof (uint32_t );
3930
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(1 ) =
3931
+ __esimd_slm_gather_ld<uint32_t , N, sizeof (uint32_t )>(
3932
+ Offset.data (), mask.data (),
3933
+ (pass_thru.template bit_cast_view <uint32_t >()
3934
+ .template select <N, 2 >(1 ))
3935
+ .data ());
3936
+ return Res;
3937
+ } else {
3938
+ using MsgT = detail::__raw_t <T>;
3939
+ return __esimd_slm_gather_ld<MsgT, N, Alignment>(
3940
+ byte_offsets.data (), mask.data (), pass_thru.data ());
3941
+ }
3908
3942
}
3909
3943
}
3910
3944
@@ -3943,16 +3977,30 @@ slm_gather(simd<uint32_t, N / VS> byte_offsets, simd_mask<N / VS> mask,
3943
3977
static_assert (Alignment >= sizeof (T),
3944
3978
" slm_gather() requires at least element-size alignment" );
3945
3979
3946
- if constexpr (VS > 1 || (!( detail::isPowerOf2 (N, 32 ) && sizeof (T) <= 4 ) &&
3980
+ if constexpr (VS > 1 || (!detail::isPowerOf2 (N, 32 ) &&
3947
3981
!detail::isMaskedGatherScatterLLVMAvailable ())) {
3948
3982
simd<T, N> PassThru; // Intentionally undefined
3949
3983
return detail::slm_gather_impl<T, VS, detail::lsc_data_size::default_size>(
3950
3984
byte_offsets, mask, PassThru);
3951
3985
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable ()) {
3952
- using MsgT = detail::__raw_t <T>;
3953
- simd<MsgT, N> PassThru; // it is intentionally undefined
3954
- return __esimd_slm_gather_ld<MsgT, N, Alignment>(
3955
- byte_offsets.data (), mask.data (), PassThru.data ());
3986
+ if constexpr (sizeof (T) == 8 ) {
3987
+ simd<T, N> Res;
3988
+ simd<uint32_t , N> PassThru; // it is intentionally undefined
3989
+
3990
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(0 ) =
3991
+ __esimd_slm_gather_ld<uint32_t , N, Alignment>(
3992
+ byte_offsets.data (), mask.data (), PassThru.data ());
3993
+ simd<uint32_t , N / VS> Offset = byte_offsets + sizeof (uint32_t );
3994
+ Res.template bit_cast_view <uint32_t >().template select <N, 2 >(1 ) =
3995
+ __esimd_slm_gather_ld<uint32_t , N, sizeof (uint32_t )>(
3996
+ Offset.data (), mask.data (), PassThru.data ());
3997
+ return Res;
3998
+ } else {
3999
+ using MsgT = detail::__raw_t <T>;
4000
+ simd<MsgT, N> PassThru; // it is intentionally undefined
4001
+ return __esimd_slm_gather_ld<MsgT, N, Alignment>(
4002
+ byte_offsets.data (), mask.data (), PassThru.data ());
4003
+ }
3956
4004
} else {
3957
4005
detail::LocalAccessorMarker acc;
3958
4006
return detail::gather_impl<T, N>(acc, byte_offsets, 0 , mask);
@@ -4236,15 +4284,30 @@ slm_scatter(simd<uint32_t, N / VS> byte_offsets, simd<T, N> vals,
4236
4284
" slm_scatter() requires at least element-size alignment" );
4237
4285
4238
4286
// Use LSC lowering if VS > 1.
4239
- if constexpr (VS > 1 || (!( detail::isPowerOf2 (N, 32 ) && sizeof (T) <= 4 ) &&
4287
+ if constexpr (VS > 1 || (!detail::isPowerOf2 (N, 32 ) &&
4240
4288
!detail::isMaskedGatherScatterLLVMAvailable ())) {
4241
4289
__ESIMD_DNS::slm_scatter_impl<T, VS, detail::lsc_data_size::default_size>(
4242
4290
byte_offsets, vals, mask);
4243
4291
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable ()) {
4244
- using MsgT = detail::__raw_t <T>;
4245
- __esimd_slm_scatter_st<MsgT, N, Alignment>(
4246
- sycl::bit_cast<__ESIMD_DNS::vector_type_t <MsgT, N>>(vals.data ()),
4247
- byte_offsets.data (), mask.data ());
4292
+ if constexpr (sizeof (T) == 8 ) {
4293
+ __esimd_slm_scatter_st<uint32_t , N, Alignment>(
4294
+ vals.template bit_cast_view <uint32_t >()
4295
+ .template select <N, 2 >(0 )
4296
+ .data (),
4297
+ byte_offsets.data (), mask.data ());
4298
+ simd<uint32_t , N / VS> Offset = byte_offsets + sizeof (uint32_t );
4299
+ __esimd_slm_scatter_st<uint32_t , N, sizeof (uint32_t )>(
4300
+ vals.template bit_cast_view <uint32_t >()
4301
+ .template select <N, 2 >(1 )
4302
+ .data (),
4303
+ Offset.data (), mask.data ());
4304
+
4305
+ } else {
4306
+ using MsgT = detail::__raw_t <T>;
4307
+ __esimd_slm_scatter_st<MsgT, N, Alignment>(
4308
+ sycl::bit_cast<__ESIMD_DNS::vector_type_t <MsgT, N>>(vals.data ()),
4309
+ byte_offsets.data (), mask.data ());
4310
+ }
4248
4311
} else {
4249
4312
detail::LocalAccessorMarker acc;
4250
4313
detail::scatter_impl<T, N>(acc, vals, byte_offsets, 0 , mask);
0 commit comments