Skip to content

Commit 3c4bd03

Browse files
YuriPlyakhin authored and igcbot committed
Logic fix for ray trace and btd spawn messages
1 parent 7d6f713 commit 3c4bd03

File tree

1 file changed

+12
-52
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ SPDX-License-Identifier: MIT
108108
#define ARR_TO_VEC1(type, arr) \
109109
arr[0]
110110

111+
typedef ushort __attribute__((ext_vector_type(32))) ushort32;
112+
113+
#define OUT_VEC32(type) type##32
111114
#define OUT_VEC16(type) type##16
112115
#define OUT_VEC8(type) type##8
113116
#define OUT_VEC7(type) type##8
@@ -118,6 +121,7 @@ SPDX-License-Identifier: MIT
118121
#define OUT_VEC2(type) type##2
119122
#define OUT_VEC1(type) type
120123

124+
#define OUT_STORE_VEC32(type) type##32
121125
#define OUT_STORE_VEC16(type) type##16
122126
#define OUT_STORE_VEC8(type) type##8
123127
#define OUT_STORE_VEC7(type) type##8
@@ -245,7 +249,8 @@ SPDX-License-Identifier: MIT
245249
wi_contrib[i] = as_int((uchar4)(row0, row1, row2, row3)); \
246250
}
247251

248-
// variants for 7,6,5,3 and 1 are only used to make the code compilable
252+
// variants for 32, 16, 7, 6, 5, 3 and 1 are only used to make the code compilable
253+
#define DEFINE_BLOCK_RW_NAME32(rw, us) intel_sub_group_block_##rw##us##32
249254
#define DEFINE_BLOCK_RW_NAME16(rw, us) intel_sub_group_block_##rw##us##16
250255
#define DEFINE_BLOCK_RW_NAME8(rw, us) intel_sub_group_block_##rw##us##8
251256
#define DEFINE_BLOCK_RW_NAME7(rw, us) intel_sub_group_block_##rw##us##8
@@ -364,7 +369,8 @@ SPDX-License-Identifier: MIT
364369
int sg_size = get_sub_group_size(); \
365370
/* When M != WI_rows, only scenarios limited to i32 row major are supported */ \
366371
bool is32bitElemHalfMRowMajor = elem_bitwidth == 32 && WI_rows == M / 2 && order == _ROW_MAJOR; \
367-
if ((WI_rows == M || is32bitElemHalfMRowMajor) && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
372+
if ((WI_rows == M || is32bitElemHalfMRowMajor) && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
373+
&& (M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
368374
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32)) \
369375
&& address_space == AS_GLOBAL \
370376
) { \
@@ -374,6 +380,7 @@ SPDX-License-Identifier: MIT
374380
&& stride == K && (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
375381
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL) \
376382
) { \
383+
OUT_STORE_VEC##M(u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(ATTRIBUTE_##address_space u##contrib_type *); \
377384
OUT_STORE_VEC##M(u##contrib_type) res = DEFINE_BLOCK_RW_NAME##M(read, us)((ATTRIBUTE_##address_space u##contrib_type *)mem); \
378385
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
379386
return; \
@@ -979,41 +986,8 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__priv
979986
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b3, c7, d7);
980987
}
981988

982-
#define DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, address_space) \
983-
INLINE void MANGLE_LOAD_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (__private char *dst, char *mem, long stride) { \
984-
int sg_size = get_sub_group_size(); \
985-
if ( BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 || M == 16) \
986-
&& (order == _ROW_MAJOR || order == _VNNI_TX) && address_space == AS_GLOBAL \
987-
) { \
988-
IMPLEMENT_BLOCK2D_LOAD(sg, order##_, element_type, contrib_type, M, K, M) \
989-
} \
990-
contrib_type *ptr = (contrib_type *)mem; \
991-
int slid = get_sub_group_local_id(); \
992-
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
993-
stride = stride / pack_factor; \
994-
int sg_cols = K / pack_factor; \
995-
int skip_factor = sg_size / sg_cols; \
996-
__private contrib_type *wi_contrib = (__private contrib_type *)dst; \
997-
for (int i = 0; i < M; i++) { \
998-
if ( (i*skip_factor + slid/sg_cols) < M ) \
999-
wi_contrib[i] = ptr[IND##order(slid, stride, skip_factor, i, sg_cols)]; \
1000-
else \
1001-
wi_contrib[i] = 0; /*last even row for matrix with odd number of rows doesn't exist*/ \
1002-
} \
1003-
}
1004-
1005-
#define DEFINE_LOAD_LARGE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us) \
1006-
DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_GENERIC) \
1007-
DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_LOCAL) \
1008-
DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_GLOBAL)
1009-
1010-
#define DEFINE_LOAD_LARGE(layout, sg, element_type, contrib_type, M, K, order, us) \
1011-
DEFINE_LOAD_LARGE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
1012-
M, K, SHAPE(layout, M, K, element_type, contrib_type), \
1013-
order, us)
1014-
1015-
DEFINE_LOAD_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, )
1016-
DEFINE_LOAD_LARGE(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, )
989+
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16)
990+
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, , 16)
1017991

1018992
#define DEFINE_ACC_ROW_MAJOR_32x64(address_space) \
1019993
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
@@ -1049,21 +1023,7 @@ DEFINE_ACC_ROW_MAJOR_32x64(generic)
10491023
DEFINE_ACC_ROW_MAJOR_32x64(global)
10501024
DEFINE_ACC_ROW_MAJOR_32x64(local)
10511025

1052-
#define DEFINE_A_ROW_MAJOR_32x16(address_space) \
1053-
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
1054-
__private char *dst0 = dst; \
1055-
__private char *dst1 = dst + 16 * (sizeof (short)); \
1056-
\
1057-
char *mem0 = mem; \
1058-
char *mem1 = mem + 16 * (sizeof (short)) * stride; \
1059-
\
1060-
__builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst0, mem0, stride); \
1061-
__builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst1, mem1, stride); \
1062-
}
1063-
1064-
DEFINE_A_ROW_MAJOR_32x16(generic)
1065-
DEFINE_A_ROW_MAJOR_32x16(global)
1066-
DEFINE_A_ROW_MAJOR_32x16(local)
1026+
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 32, 16, ROW_MAJOR, , 32)
10671027

10681028
#define DEFINE_B_B_16x64(address_space) \
10691029
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \

0 commit comments

Comments (0)