Skip to content

Commit fb7dd64

Browse files
committed
fix group_norm_f32
1 parent d33aa26 commit fb7dd64

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
342342
item_ct1.get_local_id(1);
343343
const int tid = item_ct1.get_local_id(2);
344344

345-
const int nthreads = item_ct1.get_group_range(2);
345+
const int nthreads = item_ct1.get_local_range(2);
346346
const int nwarps = nthreads / WARP_SIZE;
347347
assert(nwarps % WARP_SIZE == 0);
348348
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
@@ -456,7 +456,9 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
456456
const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
457457
int start = item_ct1.get_group(2) * group_size;
458458
int end = start + group_size;
459-
459+
const int nthreads = item_ct1.get_local_range(2);
460+
const int nwarps = nthreads / WARP_SIZE;
461+
assert(nwarps % WARP_SIZE == 0);
460462
start += item_ct1.get_local_id(2);
461463

462464
if (end >= ne_elements) {
@@ -487,7 +489,12 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
487489
better performance if there is no access to global memory.
488490
*/
489491
item_ct1.barrier();
490-
tmp = s_sum[lane_id];
492+
tmp = 0.f;
493+
int nreduce = nwarps / WARP_SIZE;
494+
for (size_t i = 0; i < nreduce; i += 1)
495+
{
496+
tmp += s_sum[lane_id + i * WARP_SIZE];
497+
}
491498
tmp = warp_reduce_sum(tmp, item_ct1);
492499
}
493500

@@ -534,7 +541,7 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
534541
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
535542
item_ct1.get_local_id(1);
536543
const int tid = item_ct1.get_local_id(2);
537-
const int nthreads = item_ct1.get_group_range(2);
544+
const int nthreads = item_ct1.get_local_range(2);
538545
const int nwarps = nthreads / WARP_SIZE;
539546
assert(nwarps % WARP_SIZE == 0);
540547
float tmp = 0.0f; // partial sum for thread in warp

0 commit comments

Comments
 (0)