Skip to content

Commit 2b8c118

Browse files
authored
[ESIMD] Fix an issue with copy_to that incorrectly copies char buffers in some circumstances (#6298)
* Fix an issue with copy_to that incorrectly copies char buffers in some circumstances
1 parent f00574c commit 2b8c118

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,9 +1381,34 @@ void simd_obj_impl<T, N, T1, SFINAE>::copy_to(
13811381
if constexpr (RemN == 1) {
13821382
Addr[NumChunks * ChunkSize] = Tmp[NumChunks * ChunkSize];
13831383
} else if constexpr (RemN == 8 || RemN == 16) {
1384-
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
1385-
scatter<UT, RemN>(Addr + (NumChunks * ChunkSize), Offsets,
1386-
Tmp.template select<RemN, 1>(NumChunks * ChunkSize));
1384+
// TODO: GPU runtime may handle scatter of 16 one-byte elements incorrectly.
1385+
// The code below is a workaround which must be deleted once GPU runtime
1386+
// is fixed.
1387+
if constexpr (sizeof(T) == 1 && RemN == 16) {
1388+
if constexpr (Align % OperandSize::DWORD > 0) {
1389+
ForHelper<RemN>::unroll([Addr, &Tmp](unsigned Index) {
1390+
Addr[Index + NumChunks * ChunkSize] =
1391+
Tmp[Index + NumChunks * ChunkSize];
1392+
});
1393+
} else {
1394+
simd_mask_type<8> Pred(0);
1395+
simd<int32_t, 8> Vals;
1396+
Pred.template select<4, 1>() = 1;
1397+
Vals.template select<4, 1>() =
1398+
Tmp.template bit_cast_view<int32_t>().template select<4, 1>(
1399+
NumChunks * ChunkSize);
1400+
1401+
simd<uint32_t, 8> Offsets(0u, sizeof(int32_t));
1402+
scatter<int32_t, 8>(
1403+
reinterpret_cast<int32_t *>(Addr + (NumChunks * ChunkSize)),
1404+
Offsets, Vals, Pred);
1405+
}
1406+
} else {
1407+
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
1408+
scatter<UT, RemN>(
1409+
Addr + (NumChunks * ChunkSize), Offsets,
1410+
Tmp.template select<RemN, 1>(NumChunks * ChunkSize));
1411+
}
13871412
} else {
13881413
constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32;
13891414
simd_mask_type<N1> Pred(0);

0 commit comments

Comments
 (0)