@@ -2239,7 +2239,7 @@ DEFN_ARITH_OPERATIONS(double)
2239
2239
DEFN_ARITH_OPERATIONS (half )
2240
2240
#endif // defined(cl_khr_fp16)
2241
2241
2242
- #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op ) \
2242
+ #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op , identity ) \
2243
2243
type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X) \
2244
2244
{ \
2245
2245
type sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, X); \
@@ -2248,19 +2248,41 @@ type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X)
2248
2248
uint num_sg = SPIRV_BUILTIN_NO_OP(BuiltInNumSubgroups, , )(); \
2249
2249
uint sg_lid = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupLocalInvocationId, , )(); \
2250
2250
uint sg_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupSize, , )(); \
2251
+ uint sg_max_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupMaxSize, , )(); \
2251
2252
\
2252
- if (sg_lid == sg_size - 1 ) { \
2253
+ if (sg_lid == 0 ) { \
2253
2254
scratch[sg_id] = sg_x; \
2254
2255
} \
2255
2256
SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2256
2257
\
2257
- type sg_aggregate = scratch[0]; \
2258
- for (int s = 1; s < num_sg; ++s) { \
2259
- sg_aggregate = op(sg_aggregate, scratch[s]); \
2258
+ uint values_num = num_sg; \
2259
+ while(values_num > sg_max_size) { \
2260
+ uint max_id = ((values_num + sg_max_size - 1) / sg_max_size) * sg_max_size; \
2261
+ uint global_id = sg_id * sg_max_size + sg_lid; \
2262
+ if (global_id < max_id) { \
2263
+ type value = global_id < values_num ? scratch[sg_id * sg_max_size + sg_lid] : identity; \
2264
+ sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value);\
2265
+ if (sg_lid == 0) { \
2266
+ scratch[sg_id] = sg_x; \
2267
+ } \
2268
+ } \
2269
+ values_num = max_id / sg_max_size; \
2270
+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2260
2271
} \
2261
2272
\
2273
+ type result; \
2274
+ if (values_num > sg_size) { \
2275
+ type sg_aggregate = scratch[0]; \
2276
+ for (int s = 1; s < values_num; ++s) { \
2277
+ sg_aggregate = op(sg_aggregate, scratch[s]); \
2278
+ } \
2279
+ result = sg_aggregate; \
2280
+ } else { \
2281
+ type value = sg_lid < values_num ? scratch[sg_lid] : identity; \
2282
+ result = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value); \
2283
+ } \
2262
2284
SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2263
- return sg_aggregate; \
2285
+ return result; \
2264
2286
}
2265
2287
2266
2288
@@ -2463,7 +2485,7 @@ DEFN_SUB_GROUP_REDUCE(func, type_abbr, type, op, identity, signed_cast)
2463
2485
DEFN_SUB_GROUP_SCAN_INCL(func, type_abbr, type, op, identity) \
2464
2486
DEFN_SUB_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
2465
2487
\
2466
- DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op) \
2488
+ DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op, identity) \
2467
2489
DEFN_WORK_GROUP_SCAN_INCL(func, type_abbr, type, op) \
2468
2490
DEFN_WORK_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
2469
2491
\
0 commit comments