Skip to content

Commit 039b538

Browse files
authored
[ESIMD] Reduce number of bit-casts generated for lsc_block_load/store operations (#8385)
1 parent 6761f0e commit 039b538

File tree

2 files changed

+67
-80
lines changed

2 files changed

+67
-80
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/detail/memory_intrin.hpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,6 @@ void __esimd_emu_lsc_offset_write(
401401
std::conditional_t<DS ==
402402
__ESIMD_ENS::lsc_data_size::u16u32,
403403
uint16_t, void>>>>>>;
404-
405404
for (int OffsetIdx = 0; OffsetIdx < N; OffsetIdx += 1) {
406405
if (Pred[OffsetIdx] == 0) {
407406
// Skip input vector elements correpsonding to
@@ -420,7 +419,12 @@ void __esimd_emu_lsc_offset_write(
420419
VecIdx += vectorIndexIncrement<N, _Transposed>()) {
421420

422421
if ((ByteDistance >= 0) && (ByteDistance < BufByteWidth)) {
423-
*((StoreType *)(WriteBase + ByteDistance)) = vals[VecIdx];
422+
if constexpr (std::is_floating_point<Ty>::value) {
423+
*((StoreType *)(WriteBase + ByteDistance)) =
424+
sycl::bit_cast<StoreType>(vals[VecIdx]);
425+
} else {
426+
*((StoreType *)(WriteBase + ByteDistance)) = vals[VecIdx];
427+
}
424428
}
425429
}
426430
}
@@ -1177,7 +1181,12 @@ __ESIMD_INTRIN void __esimd_lsc_store_stateless(
11771181
for (int ChanelIdx = 0, VecIdx = AddrIdx; ChanelIdx < ChanlCount;
11781182
ChanelIdx += 1, ByteDistance += rawAddressIncrement<Ty, DS>(),
11791183
VecIdx += vectorIndexIncrement<N, _Transposed>()) {
1180-
*((StoreType *)(BaseAddr + ByteDistance)) = vals[VecIdx];
1184+
if constexpr (std::is_floating_point<Ty>::value) {
1185+
*((StoreType *)(BaseAddr + ByteDistance)) =
1186+
sycl::bit_cast<StoreType>(vals[VecIdx]);
1187+
} else {
1188+
*((StoreType *)(BaseAddr + ByteDistance)) = vals[VecIdx];
1189+
}
11811190
}
11821191
}
11831192
}

sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp

Lines changed: 55 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,9 @@ lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred = 1) {
590590
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
591591

592592
// Prepare template arguments for the call of intrinsic.
593-
using LoadElemT =
594-
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>;
593+
using LoadElemT = std::conditional_t<
594+
std::is_floating_point<T>::value, T,
595+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
595596
constexpr uint16_t _AddressScale = 1;
596597
constexpr int _ImmOffset = 0;
597598
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
@@ -650,8 +651,9 @@ lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred,
650651
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
651652

652653
// Prepare template arguments for the call of intrinsic.
653-
using LoadElemT =
654-
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>;
654+
using LoadElemT = std::conditional_t<
655+
std::is_floating_point<T>::value, T,
656+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
655657
constexpr uint16_t _AddressScale = 1;
656658
constexpr int _ImmOffset = 0;
657659
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
@@ -714,8 +716,9 @@ lsc_block_load(AccessorTy acc, uint32_t offset,
714716
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
715717

716718
// Prepare template arguments for the call of intrinsic.
717-
using LoadElemT =
718-
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>;
719+
using LoadElemT = std::conditional_t<
720+
std::is_floating_point<T>::value, T,
721+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
719722
constexpr uint16_t _AddressScale = 1;
720723
constexpr int _ImmOffset = 0;
721724
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
@@ -779,8 +782,9 @@ lsc_block_load(AccessorTy acc, uint32_t offset, __ESIMD_NS::simd_mask<1> pred,
779782
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
780783

781784
// Prepare template arguments for the call of intrinsic.
782-
using LoadElemT =
783-
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>;
785+
using LoadElemT = std::conditional_t<
786+
std::is_floating_point<T>::value, T,
787+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
784788
constexpr uint16_t _AddressScale = 1;
785789
constexpr int _ImmOffset = 0;
786790
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
@@ -1206,43 +1210,32 @@ __ESIMD_API void lsc_block_store(T *p, __ESIMD_NS::simd<T, NElts> vals,
12061210
__ESIMD_NS::simd_mask<1> pred = 1) {
12071211
detail::check_lsc_data_size<T, DS>();
12081212
detail::check_lsc_cache_hint<detail::lsc_action::store, L1H, L3H>();
1213+
constexpr lsc_data_size FDS = detail::finalize_data_size<T, DS>();
1214+
constexpr int SmallIntFactor =
1215+
(FDS == lsc_data_size::u16) ? 2 : (FDS == lsc_data_size::u8 ? 4 : 1);
1216+
static_assert(NElts > 0 && NElts % SmallIntFactor == 0,
1217+
"Number of elements is not supported by Transposed load");
1218+
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
1219+
1220+
// Prepare template arguments for the call of intrinsic.
1221+
using StoreElemT = std::conditional_t<
1222+
std::is_floating_point<T>::value, T,
1223+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
12091224
constexpr uint16_t _AddressScale = 1;
12101225
constexpr int _ImmOffset = 0;
1211-
constexpr lsc_data_size _DS = detail::finalize_data_size<T, DS>();
1212-
constexpr detail::lsc_data_order _Transposed =
1213-
detail::lsc_data_order::transpose;
1226+
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
1227+
constexpr auto _VS = detail::to_lsc_vector_size<NElts / SmallIntFactor>();
1228+
constexpr auto _Transposed = detail::lsc_data_order::transpose;
12141229
constexpr int N = 1;
1215-
__ESIMD_NS::simd<uintptr_t, N> addrs = reinterpret_cast<uintptr_t>(p);
1216-
constexpr int SmallIntFactor =
1217-
(_DS == lsc_data_size::u16) ? 2 : (_DS == lsc_data_size::u8 ? 4 : 1);
1218-
static_assert(NElts % SmallIntFactor == 0,
1219-
"Number of elements is not supported by Transposed store");
1220-
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
1221-
constexpr detail::lsc_vector_size _VS =
1222-
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
1223-
if constexpr (SmallIntFactor == 1) {
1224-
if constexpr (_DS == lsc_data_size::u32) {
1225-
__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
1226-
_DS, _VS, _Transposed, N>(
1227-
pred.data(), addrs.data(),
1228-
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
1229-
vals.data()));
1230-
} else {
1231-
__esimd_lsc_store_stateless<uint64_t, L1H, L3H, _AddressScale, _ImmOffset,
1232-
_DS, _VS, _Transposed, N>(
1233-
pred.data(), addrs.data(),
1234-
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
1235-
vals.data()));
1236-
}
1237-
} else {
1238-
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> tmp = sycl::bit_cast<
1239-
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
1240-
vals.data());
12411230

1242-
__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
1243-
lsc_data_size::u32, _VS, _Transposed, N>(
1244-
pred.data(), addrs.data(), tmp.data());
1245-
}
1231+
__ESIMD_NS::simd<uintptr_t, N> Addrs = reinterpret_cast<uintptr_t>(p);
1232+
1233+
__esimd_lsc_store_stateless<StoreElemT, L1H, L3H, _AddressScale, _ImmOffset,
1234+
_DS, _VS, _Transposed, N>(
1235+
pred.data(), Addrs.data(),
1236+
sycl::bit_cast<
1237+
__ESIMD_DNS::vector_type_t<StoreElemT, NElts / SmallIntFactor>>(
1238+
vals.data()));
12461239
}
12471240

12481241
/// Accessor-based transposed scatter with 1 channel.
@@ -1279,48 +1272,33 @@ lsc_block_store(AccessorTy acc, uint32_t offset,
12791272
#else
12801273
detail::check_lsc_data_size<T, DS>();
12811274
detail::check_lsc_cache_hint<detail::lsc_action::store, L1H, L3H>();
1275+
constexpr lsc_data_size FDS = detail::finalize_data_size<T, DS>();
1276+
constexpr int SmallIntFactor =
1277+
(FDS == lsc_data_size::u16) ? 2 : (FDS == lsc_data_size::u8 ? 4 : 1);
1278+
static_assert(NElts > 0 && NElts % SmallIntFactor == 0,
1279+
"Number of elements is not supported by Transposed load");
1280+
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
1281+
1282+
// Prepare template arguments for the call of intrinsic.
1283+
using StoreElemT = std::conditional_t<
1284+
std::is_floating_point<T>::value, T,
1285+
std::conditional_t<FDS == lsc_data_size::u64, uint64_t, uint32_t>>;
12821286
constexpr uint16_t _AddressScale = 1;
12831287
constexpr int _ImmOffset = 0;
1284-
constexpr lsc_data_size _DS = detail::finalize_data_size<T, DS>();
1285-
constexpr detail::lsc_data_order _Transposed =
1286-
detail::lsc_data_order::transpose;
1288+
constexpr auto _DS = FDS == lsc_data_size::u64 ? FDS : lsc_data_size::u32;
1289+
constexpr auto _VS = detail::to_lsc_vector_size<NElts / SmallIntFactor>();
1290+
constexpr auto _Transposed = detail::lsc_data_order::transpose;
12871291
constexpr int N = 1;
12881292

12891293
__ESIMD_NS::simd<uint32_t, N> offsets = offset;
12901294
auto si = __ESIMD_NS::get_surface_index(acc);
1291-
constexpr int SmallIntFactor =
1292-
(_DS == lsc_data_size::u16) ? 2 : (_DS == lsc_data_size::u8 ? 4 : 1);
1293-
1294-
detail::check_lsc_vector_size<NElts / SmallIntFactor>();
1295-
static_assert(NElts % SmallIntFactor == 0,
1296-
"Number of elements is not supported by Transposed store");
1297-
constexpr detail::lsc_vector_size _VS =
1298-
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
1299-
if constexpr (SmallIntFactor > 1) {
1300-
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
1301-
lsc_data_size::u32, _VS, _Transposed, N>(
1302-
pred.data(), offsets.data(),
1303-
sycl::bit_cast<
1304-
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
1305-
vals.data()),
1306-
si);
1307-
} else {
1308-
if constexpr (_DS == lsc_data_size::u32) {
1309-
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
1310-
_VS, _Transposed, N>(
1311-
pred.data(), offsets.data(),
1312-
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
1313-
vals.data()),
1314-
si);
1315-
} else {
1316-
__esimd_lsc_store_bti<uint64_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
1317-
_VS, _Transposed, N>(
1318-
pred.data(), offsets.data(),
1319-
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
1320-
vals.data()),
1321-
si);
1322-
}
1323-
}
1295+
__esimd_lsc_store_bti<StoreElemT, L1H, L3H, _AddressScale, _ImmOffset, _DS,
1296+
_VS, _Transposed, N>(
1297+
pred.data(), offsets.data(),
1298+
sycl::bit_cast<
1299+
__ESIMD_DNS::vector_type_t<StoreElemT, NElts / SmallIntFactor>>(
1300+
vals.data()),
1301+
si);
13241302
#endif
13251303
}
13261304

0 commit comments

Comments
 (0)