Skip to content

Commit 8e24304

Browse files
admitricigcbot
authored andcommitted
Improve performance of work-group reduction built-in
Improve performance of work-group reduction built-in
1 parent 92c136e commit 8e24304

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

IGC/BiFModule/Implementation/group.cl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,7 +2239,7 @@ DEFN_ARITH_OPERATIONS(double)
22392239
DEFN_ARITH_OPERATIONS(half)
22402240
#endif // defined(cl_khr_fp16)
22412241

2242-
#define DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op) \
2242+
#define DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op, identity) \
22432243
type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X) \
22442244
{ \
22452245
type sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, X); \
@@ -2248,19 +2248,41 @@ type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X)
22482248
uint num_sg = SPIRV_BUILTIN_NO_OP(BuiltInNumSubgroups, , )(); \
22492249
uint sg_lid = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupLocalInvocationId, , )(); \
22502250
uint sg_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupSize, , )(); \
2251+
uint sg_max_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupMaxSize, , )(); \
22512252
\
2252-
if (sg_lid == sg_size - 1) { \
2253+
if (sg_lid == 0) { \
22532254
scratch[sg_id] = sg_x; \
22542255
} \
22552256
SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
22562257
\
2257-
type sg_aggregate = scratch[0]; \
2258-
for (int s = 1; s < num_sg; ++s) { \
2259-
sg_aggregate = op(sg_aggregate, scratch[s]); \
2258+
uint values_num = num_sg; \
2259+
while(values_num > sg_max_size) { \
2260+
uint max_id = ((values_num + sg_max_size - 1) / sg_max_size) * sg_max_size; \
2261+
uint global_id = sg_id * sg_max_size + sg_lid; \
2262+
if (global_id < max_id) { \
2263+
type value = global_id < values_num ? scratch[sg_id * sg_max_size + sg_lid] : identity; \
2264+
sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value);\
2265+
if (sg_lid == 0) { \
2266+
scratch[sg_id] = sg_x; \
2267+
} \
2268+
} \
2269+
values_num = max_id / sg_max_size; \
2270+
SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
22602271
} \
22612272
\
2273+
type result; \
2274+
if (values_num > sg_size) { \
2275+
type sg_aggregate = scratch[0]; \
2276+
for (int s = 1; s < values_num; ++s) { \
2277+
sg_aggregate = op(sg_aggregate, scratch[s]); \
2278+
} \
2279+
result = sg_aggregate; \
2280+
} else { \
2281+
type value = sg_lid < values_num ? scratch[sg_lid] : identity; \
2282+
result = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value); \
2283+
} \
22622284
SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2263-
return sg_aggregate; \
2285+
return result; \
22642286
}
22652287

22662288

@@ -2463,7 +2485,7 @@ DEFN_SUB_GROUP_REDUCE(func, type_abbr, type, op, identity, signed_cast)
24632485
DEFN_SUB_GROUP_SCAN_INCL(func, type_abbr, type, op, identity) \
24642486
DEFN_SUB_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
24652487
\
2466-
DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op) \
2488+
DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op, identity) \
24672489
DEFN_WORK_GROUP_SCAN_INCL(func, type_abbr, type, op) \
24682490
DEFN_WORK_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
24692491
\

0 commit comments

Comments
 (0)