@@ -342,7 +342,7 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
                     item_ct1.get_local_id(1);
     const int tid = item_ct1.get_local_id(2);

-    const int nthreads = item_ct1.get_group_range(2);
+    const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
     assert(nwarps % WARP_SIZE == 0);
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);
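The swapped call matters because the two queries answer different questions: `nd_item::get_local_range(2)` returns the number of work-items in the work-group along dimension 2, which is what `nthreads` is meant to hold, while `nd_item::get_group_range(2)` returns the number of work-groups along that dimension. A minimal standalone sketch, assuming a SYCL 2020 toolchain; the launch shape and the `WARP_SIZE` value are illustrative, not taken from the patch:

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

constexpr int WARP_SIZE = 32;   // illustrative; ggml defines its own value

int main() {
    sycl::queue q;

    // 4 work-groups of 256 work-items along dimension 2 (illustrative shape).
    const sycl::nd_range<3> range(sycl::range<3>(1, 1, 4 * 256),
                                  sycl::range<3>(1, 1, 256));

    q.parallel_for(range, [=](sycl::nd_item<3> it) {
        // Work-items per work-group along dim 2 -> 256: what nthreads needs.
        const int nthreads = it.get_local_range(2);
        // Number of work-groups along dim 2 -> 4: what the old code read.
        const int ngroups  = it.get_group_range(2);

        // Derived from the group size, nwarps is 256 / 32 = 8; derived from
        // the group count it would be 4 / 32 = 0, breaking the reduction.
        const int nwarps = nthreads / WARP_SIZE;
        (void) ngroups;
        (void) nwarps;
    }).wait();

    std::printf("kernel finished\n");
    return 0;
}
```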
@@ -456,7 +456,9 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
                            const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
     int start = item_ct1.get_group(2) * group_size;
     int end = start + group_size;
-
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+    assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);

     if (end >= ne_elements) {
@@ -487,7 +489,12 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
     better performance if there is no access to global memory.
     */
     item_ct1.barrier();
-    tmp = s_sum[lane_id];
+    tmp = 0.f;
+    int nreduce = nwarps / WARP_SIZE;
+    for (size_t i = 0; i < nreduce; i += 1)
+    {
+        tmp += s_sum[lane_id + i * WARP_SIZE];
+    }
     tmp = warp_reduce_sum(tmp, item_ct1);
     }

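The added loop completes the second stage of the usual two-stage work-group reduction: each warp first reduces its own values and writes its partial sum to shared memory, one slot per warp, and after the barrier one warp combines those partials. The old `tmp = s_sum[lane_id]` covers at most `WARP_SIZE` partials; when a work-group contains more warps than that, each lane now accumulates several entries of `s_sum` before the final `warp_reduce_sum`. A host-side sketch of that indexing in plain C++, with made-up data and a serial loop standing in for `warp_reduce_sum`:

```cpp
#include <cstdio>
#include <vector>

constexpr int WARP_SIZE = 32;   // illustrative; ggml defines its own value

// Sum the per-warp partials the way the patched kernel does: lane `lane_id`
// accumulates s_sum[lane_id], s_sum[lane_id + WARP_SIZE], ... and the
// lane-local sums are then combined (warp_reduce_sum in the kernel).
static float second_stage_reduce(const std::vector<float> & s_sum, int nwarps) {
    const int nreduce = nwarps / WARP_SIZE;   // same integer division as the diff
    float total = 0.f;
    for (int lane_id = 0; lane_id < WARP_SIZE; ++lane_id) {
        float tmp = 0.f;
        for (int i = 0; i < nreduce; ++i) {
            tmp += s_sum[lane_id + i * WARP_SIZE];   // the added loop body
        }
        total += tmp;   // stands in for warp_reduce_sum across the lanes
    }
    return total;
}

int main() {
    // 64 warps -> 64 partial sums; the old `tmp = s_sum[lane_id]` would read
    // only the first 32 of them and silently drop the rest.
    const std::vector<float> s_sum(64, 1.0f);
    std::printf("reduced: %.1f (expected 64.0)\n", second_stage_reduce(s_sum, 64));
    return 0;
}
```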
@@ -534,7 +541,7 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);
     const int tid = item_ct1.get_local_id(2);
-    const int nthreads = item_ct1.get_group_range(2);
+    const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
     assert(nwarps % WARP_SIZE == 0);
     float tmp = 0.0f; // partial sum for thread in warp