Skip to content

Commit 079f16a

Browse files
committed
fix softmax
1 parent 7bff5f9 commit 079f16a

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
1414

1515
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
1616
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
17-
17+
const int nthreads = block_size;
18+
const int nwarps = nthreads / WARP_SIZE;
19+
int nreduce = nwarps / WARP_SIZE;
1820
float slope = 1.0f;
1921

2022
// ALiBi
@@ -27,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
2729
slope = sycl::pow(base, float(exp));
2830
}
2931

30-
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
32+
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
3133
float max_val = -INFINITY;
3234

3335
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -51,20 +53,24 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
5153
if (block_size > WARP_SIZE) {
5254
if (warp_id == 0) {
5355
buf[lane_id] = -INFINITY;
56+
for (size_t i = 1; i < nreduce; i += 1)
57+
buf[lane_id + i * WARP_SIZE] = -INFINITY;
5458
}
5559
item_ct1.barrier(sycl::access::fence_space::local_space);
5660

5761
if (lane_id == 0) {
5862
buf[warp_id] = max_val;
5963
}
6064
item_ct1.barrier(sycl::access::fence_space::local_space);
61-
6265
max_val = buf[lane_id];
66+
for (size_t i = 1; i < nreduce; i += 1)
67+
{
68+
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
69+
}
6370
max_val = warp_reduce_max(max_val, item_ct1);
6471
}
6572

6673
float tmp = 0.f;
67-
6874
#pragma unroll
6975
for (int col0 = 0; col0 < ncols; col0 += block_size) {
7076
const int col = col0 + tid;
@@ -83,6 +89,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
8389
item_ct1.barrier(sycl::access::fence_space::local_space);
8490
if (warp_id == 0) {
8591
buf[lane_id] = 0.f;
92+
for (size_t i = 1; i < nreduce; i += 1)
93+
buf[lane_id + i * WARP_SIZE] = 0.f;
8694
}
8795
item_ct1.barrier(sycl::access::fence_space::local_space);
8896

@@ -92,6 +100,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
92100
item_ct1.barrier(sycl::access::fence_space::local_space);
93101

94102
tmp = buf[lane_id];
103+
for (size_t i = 1; i < nreduce; i += 1)
104+
{
105+
tmp += buf[lane_id + i * WARP_SIZE];
106+
}
95107
tmp = warp_reduce_sum(tmp, item_ct1);
96108
}
97109

0 commit comments

Comments
 (0)