Skip to content

Commit e076c04

Browse files
authored
[SYCL][ESIMD] Fix load_2d inconsistency when reading non native types with VNNI transforms (#15584)
1 parent 1f12cae commit e076c04

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4070,7 +4070,7 @@ __ESIMD_API simd<T, N> load_2d_impl(const T *Ptr, unsigned SurfaceWidth,
40704070
uintptr_t Addr = reinterpret_cast<uintptr_t>(Ptr);
40714071
constexpr lsc_data_order Transpose =
40724072
Transposed ? lsc_data_order::transpose : lsc_data_order::nontranspose;
4073-
simd<RawT, ActualN> Raw =
4073+
simd<T, ActualN> Raw =
40744074
__esimd_lsc_load2d_stateless<RawT, L1H, L2H, DS, Transpose, NBlocks,
40754075
BlockWidth, BlockHeight, Transformed,
40764076
ActualN>(Mask.data(), Addr, SurfaceWidth,
@@ -4096,17 +4096,16 @@ __ESIMD_API simd<T, N> load_2d_impl(const T *Ptr, unsigned SurfaceWidth,
40964096
// +----+----+----+----+----+----+-----+-----+
40974097
// * signifies the padded element.
40984098

4099-
simd<RawT, DstElements> Dst;
4099+
simd<T, DstElements> Dst;
41004100

41014101
for (auto i = 0; i < NBlocks; i++) {
41024102
auto DstBlock =
41034103
Dst.template select<DstBlockElements, 1>(i * DstBlockElements);
41044104

41054105
auto RawBlock = Raw.template select<GRFBlockSize, 1>(i * GRFBlockPitch);
4106-
DstBlock =
4107-
RawBlock.template bit_cast_view<RawT, GRFColSize, GRFRowPitch>()
4108-
.template select<GRFColSize, 1, GRFRowSize, 1>(0, 0)
4109-
.template bit_cast_view<RawT>();
4106+
DstBlock = RawBlock.template bit_cast_view<T, GRFColSize, GRFRowPitch>()
4107+
.template select<GRFColSize, 1, GRFRowSize, 1>(0, 0)
4108+
.template bit_cast_view<T>();
41104109
}
41114110

41124111
return Dst;

sycl/test-e2e/ESIMD/lsc/lsc_load_2d_compare.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,12 +79,15 @@ int main() {
7979
result |= test<uint16_t>();
8080
result |= test<uint8_t>();
8181
result |= test<sycl::half>();
82+
result |= test<bf16>();
8283

8384
result |= test<float, true>();
8485
result |= test<uint32_t, true>();
8586

8687
result |= test<uint16_t, false, true>();
8788
result |= test<uint8_t, false, true>();
89+
result |= test<sycl::half, false, true>();
90+
result |= test<bf16, false, true>();
8891

8992
std::cout << (result ? "FAILED" : "passed") << std::endl;
9093
return 0;

0 commit comments

Comments
 (0)