@@ -14,7 +14,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
14
14
15
15
const int warp_id = item_ct1.get_local_id (2 ) / WARP_SIZE;
16
16
const int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
17
-
17
+ const int nthreads = block_size;
18
+ const int nwarps = nthreads / WARP_SIZE;
19
+ int nreduce = nwarps / WARP_SIZE;
18
20
float slope = 1 .0f ;
19
21
20
22
// ALiBi
@@ -27,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
27
29
slope = sycl::pow (base, float (exp));
28
30
}
29
31
30
- float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx* ncols;
32
+ float *vals = vals_smem ? buf + std::max (nwarps, WARP_SIZE) : dst + rowx * ncols;
31
33
float max_val = -INFINITY;
32
34
33
35
for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
@@ -51,20 +53,24 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
51
53
if (block_size > WARP_SIZE) {
52
54
if (warp_id == 0 ) {
53
55
buf[lane_id] = -INFINITY;
56
+ for (size_t i = 1 ; i < nreduce; i += 1 )
57
+ buf[lane_id + i * WARP_SIZE] = -INFINITY;
54
58
}
55
59
item_ct1.barrier (sycl::access::fence_space::local_space);
56
60
57
61
if (lane_id == 0 ) {
58
62
buf[warp_id] = max_val;
59
63
}
60
64
item_ct1.barrier (sycl::access::fence_space::local_space);
61
-
62
65
max_val = buf[lane_id];
66
+ for (size_t i = 1 ; i < nreduce; i += 1 )
67
+ {
68
+ max_val = std::max (max_val, buf[lane_id + i * WARP_SIZE]);
69
+ }
63
70
max_val = warp_reduce_max (max_val, item_ct1);
64
71
}
65
72
66
73
float tmp = 0 .f ;
67
-
68
74
#pragma unroll
69
75
for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
70
76
const int col = col0 + tid;
@@ -83,6 +89,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
83
89
item_ct1.barrier (sycl::access::fence_space::local_space);
84
90
if (warp_id == 0 ) {
85
91
buf[lane_id] = 0 .f ;
92
+ for (size_t i = 1 ; i < nreduce; i += 1 )
93
+ buf[lane_id + i * WARP_SIZE] = 0 .f ;
86
94
}
87
95
item_ct1.barrier (sycl::access::fence_space::local_space);
88
96
@@ -92,6 +100,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
92
100
item_ct1.barrier (sycl::access::fence_space::local_space);
93
101
94
102
tmp = buf[lane_id];
103
+ for (size_t i = 1 ; i < nreduce; i += 1 )
104
+ {
105
+ tmp += buf[lane_id + i * WARP_SIZE];
106
+ }
95
107
tmp = warp_reduce_sum (tmp, item_ct1);
96
108
}
97
109
0 commit comments